From 80995e42fdd6ef59d1d86e66a8595d4b7e410cad Mon Sep 17 00:00:00 2001 From: cherishPre <2649458107@qq.com> Date: Wed, 20 Dec 2023 11:39:22 +0800 Subject: [PATCH 1/2] TETFN and CENET models added --- README.md | 6 +- results/result-stat.md | 10 +- setup.cfg | 1 + src/MMSA/__main__.py | 4 +- src/MMSA/config/config_regression.json | 243 ++++++++++++ src/MMSA/config/config_tune.json | 127 +++++++ src/MMSA/models/AMIO.py | 11 +- src/MMSA/models/multiTask/TETFN.py | 259 +++++++++++++ src/MMSA/models/multiTask/__init__.py | 1 + src/MMSA/models/singleTask/CENET.py | 487 +++++++++++++++++++++++++ src/MMSA/models/singleTask/__init__.py | 3 +- src/MMSA/models/subNets/AlignNets.py | 24 +- src/MMSA/trains/ATIO.py | 2 + src/MMSA/trains/multiTask/MLF_DNN.py | 4 +- src/MMSA/trains/multiTask/MLMF.py | 3 - src/MMSA/trains/multiTask/MTFN.py | 3 - src/MMSA/trains/multiTask/SELF_MM.py | 4 +- src/MMSA/trains/multiTask/TETFN.py | 332 +++++++++++++++++ src/MMSA/trains/multiTask/__init__.py | 1 + src/MMSA/trains/singleTask/CENET.py | 125 +++++++ src/MMSA/trains/singleTask/__init__.py | 3 +- 21 files changed, 1626 insertions(+), 27 deletions(-) create mode 100644 src/MMSA/models/multiTask/TETFN.py create mode 100644 src/MMSA/models/singleTask/CENET.py create mode 100644 src/MMSA/trains/multiTask/TETFN.py create mode 100644 src/MMSA/trains/singleTask/CENET.py diff --git a/README.md b/README.md index 77ec560..f90d180 100644 --- a/README.md +++ b/README.md @@ -145,11 +145,13 @@ MMSA uses feature files that are organized as follows: | Multi-Task | [MLF_DNN](src/MMSA/models/multiTask/MLF_DNN.py) | [MMSA](https://github.com/thuiar/MMSA) | ACL 2020 | | Multi-Task | [MTFN](src/MMSA/models/multiTask/MTFN.py) | [MMSA](https://github.com/thuiar/MMSA) | ACL 2020 | | Multi-Task | [MLMF](src/MMSA/models/multiTask/MLMF.py) | [MMSA](https://github.com/thuiar/MMSA) | ACL 2020 | +| Multi-Task | [SELF_MM](src/MMSA/models/multiTask/SELF_MM.py) | [Self-MM](https://github.com/thuiar/Self-MM) | AAAI 2021 +| Multi-Task | [TETFN](src/MMSA/models/multiTask/TETFN.py) | TETFN | PR 2023 | Single-Task | [BERT-MAG](src/MMSA/models/singleTask/BERT_MAG.py) | [MAG-BERT](https://github.com/WasifurRahman/BERT_multimodal_transformer) | ACL 2020 | -| Single-Task | [MISA](src/MMSA/models/singleTask/MISA.py) | [MISA](https://github.com/declare-lab/MISA) | ACMMM 2020 | -| Single-Task | [SELF_MM](src/MMSA/models/multiTask/SELF_MM.py) | [Self-MM](https://github.com/thuiar/Self-MM) | AAAI 2021 | +| Single-Task | [MISA](src/MMSA/models/singleTask/MISA.py) | [MISA](https://github.com/declare-lab/MISA) | ACMMM 2020 | | | Single-Task | [MMIM](src/MMSA/models/singleTask/MMIM.py) | [MMIM](https://github.com/declare-lab/Multimodal-Infomax) | EMNLP 2021 | | Single-Task | BBFN (Work in Progress) | [BBFN](https://github.com/declare-lab/BBFN) | ICMI 2021 | +| Single-Task | [CENET](src/MMSA/models/singleTask/CENET.py) | [CENET](https://github.com/Say2L/CENet) | TMM 2022 | ## 4. 
Results diff --git a/results/result-stat.md b/results/result-stat.md index e2219c9..87648f1 100644 --- a/results/result-stat.md +++ b/results/result-stat.md @@ -13,20 +13,24 @@ | mult |79.71 |79.63 |80.98 |80.95 |42.68 |36.91 |87.99 |70.22 | Unaligned | | misa |81.84 |81.82 |83.54 |83.58 |47.08 |41.37 |77.65 |77.81 | Unaligned | | self_mm |83.44 |83.36 |85.46 |85.43 |53.47 |46.67 |70.80 |79.63 | Unaligned | +| tetfn |83.24 |83.13 |85.37 |85.33 |53.64 |45.77 |70.84 |79.84 | Aligned | +| cenet |83.53 |83.49 |85.21 |85.22 |50.87 |44.90 |72.54 |79.53 | Unaligned | - MOSEI | Model |Has0_acc_2 |Has0_F1_score |Non0_acc_2 |Non0_F1_score |Mult_acc_5 |Mult_acc_7 |MAE |Corr | Data Setting| | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | ef_lstm |77.84 |78.34 |80.79 |80.67 |51.16 |50.01 |60.05 |68.25 | Aligned | -| lf_dnn |80.6 |80.85 |82.74 |82.52 |51.97 |50.83 |58.02 |70.87 | Unaligned | -| tfn |78.50 |78.96 |81.89 |81.74 |53.1 |51.6 |57.26 |71.41 | Unaligned | +| lf_dnn |80.60 |80.85 |82.74 |82.52 |51.97 |50.83 |58.02 |70.87 | Unaligned | +| tfn |78.50 |78.96 |81.89 |81.74 |53.10 |51.60 |57.26 |71.41 | Unaligned | | lmf |80.54 |80.94 |83.48 |83.36 |52.99 |51.59 |57.57 |71.69 | Unaligned | | mfn |78.94 |79.55 |82.86 |82.85 |52.76 |51.34 |57.33 |71.82 | Aligned | | graph_mfn |81.28 |81.48 |83.48 |83.23 |52.69 |51.37 |57.45 |71.33 | Aligned | | mult |81.15 |81.56 |84.63 |84.52 |54.18 |52.84 |55.93 |73.31 | Unaligned | | misa |80.67 |81.12 |84.67 |84.66 |53.63 |52.05 |55.75 |75.15 | Unaligned | | self_mm |83.76 |83.82 |85.15 |84.90 |55.53 |53.87 |53.09 |76.49 | Unaligned | +| tetfn |84.12 |84.35 |86.21 |86.11 |55.78 |53.90 |53.73 |76.96 | Aligned | +| cenet |83.52 |83.85 |86.38 |86.32 |56.15 |54.26 |52.59 |77.75 | Unaligned | - SIMS @@ -43,6 +47,8 @@ | mtfn |81.09 |68.80 |40.31 |81.01 |39.54 |66.58 | Unaligned | | mlmf |79.34 |68.36 |41.05 |79.07 |40.91 |63.90 | Unaligned | | self_mm |80.04 |65.47 |41.53 |80.44 |42.50 |59.52 | Unaligned | +| tetfn |81.18 |63.24 |41.79 |80.24 |42.00 |57.65 | Unaligned | +| cenet |77.90 |62.58 |33.92 |77.53 |47.09 |53.95 | Unaligned | ## Classification > Data setting is the same as `Regression` diff --git a/setup.cfg b/setup.cfg index fba1b63..552e069 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ install_requires = nvidia-ml-py3 >= 7.352.0 scikit-learn >= 0.24.2 easydict >= 1.9 + pytorch_transformers >= 1.2.0 [options.packages.find] where = src \ No newline at end of file diff --git a/src/MMSA/__main__.py b/src/MMSA/__main__.py index 7a7d984..60c6faf 100644 --- a/src/MMSA/__main__.py +++ b/src/MMSA/__main__.py @@ -6,8 +6,8 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('-m', '--model', type=str, default='lf_dnn', help='Name of model', - choices=['lf_dnn', 'ef_lstm', 'tfn', 'lmf', 'mfn', 'graph_mfn', 'mult', 'bert_mag', - 'misa', 'mfm', 'mlf_dnn', 'mtfn', 'mlmf', 'self_mm', 'mmim']) + choices=['lf_dnn', 'ef_lstm', 'tfn', 'mctn','lmf', 'mfn', 'graph_mfn', 'mult', 'bert_mag', + 'misa', 'mfm', 'mlf_dnn', 'mtfn', 'mlmf', 'self_mm', 'mmim','tfr_net','tetfn','cenet']) parser.add_argument('-d', '--dataset', type=str, default='sims', choices=['sims', 'mosi', 'mosei'], help='Name of dataset') parser.add_argument('-c', '--config', type=str, default='', diff --git a/src/MMSA/config/config_regression.json b/src/MMSA/config/config_regression.json index d1b9711..a29b8c8 100644 --- a/src/MMSA/config/config_regression.json +++ b/src/MMSA/config/config_regression.json @@ -61,6 +61,19 @@ "missing_rate": [0.2, 0.2, 
0.2], "missing_seed": [1111, 1111, 1111] } + }, + "simsv2": { + "unaligned": { + "featurePath": "SIMSv2/Processed/sims_unaligned.pkl", + "seq_lens": [39, 400, 55], + "feature_dims": [768, 33, 709], + "train_samples": 2722, + "num_classes": 3, + "language": "cn", + "KeyEval": "Loss", + "missing_rate": [0.2, 0.2, 0.2], + "missing_seed": [1111, 1111, 1111] + } } }, "tfn": { @@ -1171,5 +1184,235 @@ "num_temporal_head": 25 } } + }, + "tetfn": { + "commonParams": { + "need_data_aligned": true, + "need_model_aligned": true, + "need_normalized": false, + "use_bert": true, + "use_finetune": true, + "save_labels": false, + "early_stop": 8, + "update_epochs": 4 + }, + "datasetParams": { + "mosi": { + "batch_size": 64, + "transformers": "bert", + "pretrained": "bert-base-uncased", + "a_lstm_hidden_size":16, + "a_lstm_layers":1, + "a_lstm_dropout":0.0, + "v_lstm_hidden_size":64, + "v_lstm_layers":1, + "v_lstm_dropout":0.0, + "conv1d_kernel_size_l":1, + "conv1d_kernel_size_a":1, + "conv1d_kernel_size_v":3, + "dst_feature_dims":50, + "nheads":5, + "attn_dropout":0.1, + "attn_dropout_a":0.0, + "attn_dropout_v":0.1, + "relu_dropout":0.0, + "embed_dropout":0.0, + "res_dropout":0.1, + "post_fusion_dropout":0.0, + "post_fusion_dim":64, + "post_text_dropout":0.0, + "post_text_dim":32, + "post_audio_dropout":0.0, + "post_audio_dim":32, + "post_video_dropout":0.0, + "post_video_dim":16, + "train_samples":1284, + "excludeZero":true, + "update_epochs":4, + "H":3, + "decay":false, + "weight_decay_bert":0.001, + "learning_rate_bert":3e-5, + "weight_decay_audio":0.01, + "weight_decay_video":0.0, + "weight_decay_other":0.01, + "learning_rate_audio":0.0005, + "learning_rate_video":0.0003, + "learning_rate_other":0.0003 + }, + "mosei": { + "batch_size": 32, + "transformers": "bert", + "pretrained": "bert-base-uncased", + "a_lstm_hidden_size":32, + "a_lstm_layers":1, + "a_lstm_dropout":0.0, + "v_lstm_hidden_size":32, + "v_lstm_layers":1, + "v_lstm_dropout":0.0, + "conv1d_kernel_size_l":1, + "conv1d_kernel_size_a":1, + "conv1d_kernel_size_v":1, + "dst_feature_dims":50, + "nheads":5, + "attn_dropout":0.1, + "attn_dropout_a":0.0, + "attn_dropout_v":0.1, + "relu_dropout":0.1, + "embed_dropout":0.1, + "res_dropout":0.0, + "post_fusion_dropout":0.0, + "post_fusion_dim":64, + "post_text_dropout":0.1, + "post_text_dim":64, + "post_audio_dropout":0.0, + "post_audio_dim":32, + "post_video_dropout":0.1, + "post_video_dim":16, + "train_samples":16326, + "excludeZero":true, + "update_epochs":4, + "H":3, + "decay":false, + "weight_decay_bert":0.001, + "learning_rate_bert":3e-5, + "weight_decay_audio":0.0, + "weight_decay_video":0.001, + "weight_decay_other":0.01, + "learning_rate_audio":0.001, + "learning_rate_video":0.005, + "learning_rate_other":0.0001 + }, + "sims": { + "batch_size": 64, + "transformers": "bert", + "pretrained": "bert-base-chinese", + "a_lstm_hidden_size":16, + "a_lstm_layers":1, + "a_lstm_dropout":0.0, + "v_lstm_hidden_size":64, + "v_lstm_layers":1, + "v_lstm_dropout":0.0, + "conv1d_kernel_size_l":3, + "conv1d_kernel_size_a":5, + "conv1d_kernel_size_v":1, + "dst_feature_dims":50, + "nheads":5, + "attn_dropout":0.0, + "attn_dropout_a":0.1, + "attn_dropout_v":0.0, + "relu_dropout":0.1, + "embed_dropout":0.0, + "res_dropout":0.0, + "post_fusion_dropout":0.1, + "post_fusion_dim":64, + "post_text_dropout":0.0, + "post_text_dim":64, + "post_audio_dropout":0.0, + "post_audio_dim":16, + "post_video_dropout":0.0, + "post_video_dim":32, + "train_samples":1368, + "excludeZero":true, + "update_epochs":4, + "H":3, + 
"decay":false, + "weight_decay_bert":0.001, + "learning_rate_bert":3e-5, + "weight_decay_audio":0.0, + "weight_decay_video":0.01, + "weight_decay_other":0.001, + "learning_rate_audio":0.003, + "learning_rate_video":0.0005, + "learning_rate_other":0.003 + }, + "simsv2": { + "batch_size": 64, + "transformers": "bert", + "pretrained": "bert-base-chinese", + "a_lstm_hidden_size":16, + "a_lstm_layers":1, + "a_lstm_dropout":0.0, + "v_lstm_hidden_size":32, + "v_lstm_layers":1, + "v_lstm_dropout":0.0, + "conv1d_kernel_size_l":5, + "conv1d_kernel_size_a":3, + "conv1d_kernel_size_v":1, + "dst_feature_dims":50, + "nheads":5, + "attn_dropout":0.1, + "attn_dropout_a":0.0, + "attn_dropout_v":0.1, + "relu_dropout":0.0, + "embed_dropout":0.0, + "res_dropout":0.1, + "post_fusion_dropout":0.1, + "post_fusion_dim":128, + "post_text_dropout":0.1, + "post_text_dim":64, + "post_audio_dropout":0.0, + "post_audio_dim":16, + "post_video_dropout":0.0, + "post_video_dim":32, + "train_samples":2722, + "excludeZero":true, + "update_epochs":4, + "H":3, + "decay":false, + "weight_decay_bert":0.001, + "learning_rate_bert":5e-6, + "weight_decay_audio":0.001, + "weight_decay_video":0.001, + "weight_decay_other":0.001, + "learning_rate_audio":0.0005, + "learning_rate_video":0.005, + "learning_rate_other":0.0003 + } + } + }, + "cenet":{ + "commonParams": { + "need_data_aligned": false, + "need_model_aligned": false, + "need_normalized": false, + "use_bert": true, + "use_finetune": true, + "early_stop": 8 + }, + "datasetParams": { + "mosi": { + "pretrained": "bert-base-uncased", + "learning_rate":1e-5, + "weight_decay":0.0001, + "max_grad_norm":2, + "adam_epsilon":3e-8, + "batch_size":64 + }, + "mosei":{ + "pretrained": "bert-base-uncased", + "learning_rate":1e-5, + "weight_decay":0.0001, + "max_grad_norm":2, + "adam_epsilon":1e-8, + "batch_size":64 + }, + "sims":{ + "pretrained": "bert-base-chinese", + "learning_rate":2e-6, + "weight_decay":0.0, + "max_grad_norm":2, + "adam_epsilon":2e-8, + "batch_size":32 + }, + "simsv2":{ + "pretrained": "bert-base-chinese", + "learning_rate":3e-5, + "weight_decay":0.0, + "max_grad_norm":2, + "adam_epsilon":3e-8, + "batch_size":64 + } + } } } diff --git a/src/MMSA/config/config_tune.json b/src/MMSA/config/config_tune.json index 06f652e..4fe6a04 100644 --- a/src/MMSA/config/config_tune.json +++ b/src/MMSA/config/config_tune.json @@ -51,6 +51,19 @@ "language": "cn", "KeyEval": "Loss" } + }, + "simsv2": { + "unaligned": { + "featurePath": "SIMSv2/Processed/sims_unaligned.pkl", + "seq_lens": [39, 400, 55], + "feature_dims": [768, 33, 709], + "train_samples": 2722, + "num_classes": 3, + "language": "cn", + "KeyEval": "Loss", + "missing_rate": [0.2, 0.2, 0.2], + "missing_seed": [1111, 1111, 1111] + } } }, "tfn": { @@ -944,5 +957,119 @@ [4e-6, 1e-6, 2e-6] ] } + }, + "tetfn": { + "commonParams": { + "need_data_aligned": true, + "need_model_aligned": true, + "need_normalized": false, + "use_bert": true, + "use_finetune": true, + "save_labels": false, + "early_stop": 8, + "update_epochs": 4, + "excludeZero":true, + "decay":false, + "nheads":5, + "dst_feature_dims":50, + "transformers": "bert", + "pretrained": "bert-base-uncased" + }, + "debugParams": { + "d_paras": [ + "batch_size", + "learning_rate_bert", + "learning_rate_audio", + "learning_rate_video", + "learning_rate_other", + "weight_decay_bert", + "weight_decay_other", + "weight_decay_audio", + "weight_decay_video", + "a_lstm_hidden_size", + "v_lstm_hidden_size", + "a_lstm_dropout", + "v_lstm_dropout", + "a_lstm_layers", + "v_lstm_layers", + 
"post_fusion_dim", + "post_text_dim", + "post_audio_dim", + "post_video_dim", + "conv1d_kernel_size_l", + "conv1d_kernel_size_a", + "conv1d_kernel_size_v", + "post_fusion_dropout", + "post_text_dropout", + "post_audio_dropout", + "post_video_dropout", + "attn_dropout", + "attn_dropout_a", + "attn_dropout_v", + "relu_dropout", + "embed_dropout", + "res_dropout", + "H" + ], + "batch_size": [32,64,128], + "learning_rate_bert": [5e-6,1e-5,3e-5,5e-5], + "learning_rate_audio": [0.0001, 0.0003, 0.0005, 0.001, 0.003, 0.005], + "learning_rate_video": [0.0001, 0.0003, 0.0005, 0.001, 0.003, 0.005], + "learning_rate_other": [0.0001, 0.0003, 0.0005, 0.001, 0.003, 0.005], + "weight_decay_bert": [0.001, 0.01], + "weight_decay_audio": [0.0, 0.001, 0.01], + "weight_decay_video": [0.0, 0.001, 0.01], + "weight_decay_other": [0.001, 0.01], + "a_lstm_hidden_size": [16, 32], + "v_lstm_hidden_size": [32, 64], + "a_lstm_layers": 1, + "v_lstm_layers": 1, + "a_lstm_dropout": [0.0], + "v_lstm_dropout": [0.0], + "post_fusion_dim": [64, 128], + "post_text_dim": [32, 64], + "post_audio_dim": [16, 32], + "post_video_dim": [16, 32], + "post_fusion_dropout": [0.1, 0.0], + "post_text_dropout": [0.1, 0.0], + "post_audio_dropout": [0.1, 0.0], + "post_video_dropout": [0.1, 0.0], + "attn_dropout": [0.1, 0.0], + "attn_dropout_a": [0.1, 0.0], + "attn_dropout_v": [0.1, 0.0], + "relu_dropout": [0.1, 0.0], + "embed_dropout": [0.1, 0.0], + "res_dropout": [0.1, 0.0], + "conv1d_kernel_size_l": [1, 3, 5], + "conv1d_kernel_size_a": [1, 3, 5], + "conv1d_kernel_size_v": [1, 3, 5], + "H": [3.0] + } + }, + "cenet": { + "commonParams": { + "need_data_aligned": false, + "need_model_aligned": false, + "need_normalized": false, + "use_bert": true, + "use_finetune": true, + "early_stop": 8, + "transformers": "bert", + "pretrained": "bert-base-uncased" + }, + "debugParams": { + "d_paras": [ + "learning_rate", + "weight_decay", + "max_grad_norm", + "adam_epsilon", + "batch_size" + ], + "learning_rate":[1e-6,2e-6,5e-6,1e-5,3e-5,5e-5], + "weight_decay":[0.0, 0.0001], + "max_grad_norm":2, + "adam_epsilon":[1e-8,2e-8,3e-8], + "batch_size":[32,64,128] + } } } diff --git a/src/MMSA/models/AMIO.py b/src/MMSA/models/AMIO.py index 37c6393..90aea91 100644 --- a/src/MMSA/models/AMIO.py +++ b/src/MMSA/models/AMIO.py @@ -7,7 +7,7 @@ from .singleTask import * from .missingTask import * from .subNets import AlignSubNet - +from pytorch_transformers import BertConfig class AMIO(nn.Module): def __init__(self, args): @@ -26,11 +26,13 @@ def __init__(self, args): 'misa': MISA, 'mfm': MFM, 'mmim': MMIM, + 'cenet': CENET, # multi-task 'mtfn': MTFN, 'mlmf': MLMF, 'mlf_dnn': MLF_DNN, 'self_mm': SELF_MM, + 'tetfn': TETFN, # missing-task 'tfr_net': TFR_NET } @@ -41,7 +43,12 @@ def __init__(self, args): if 'seq_lens' in args.keys(): args['seq_lens'] = self.alignNet.get_seq_len() lastModel = self.MODEL_MAP[args['model_name']] - self.Model = lastModel(args) + + if args.model_name == 'cenet': + config = BertConfig.from_pretrained(args.pretrained, num_labels=1, finetuning_task='sst') + self.Model = CENET.from_pretrained(args.pretrained, config=config, pos_tag_embedding=True, senti_embedding=True, polarity_embedding=True, args=args) + else: + self.Model = lastModel(args) def forward(self, text_x, audio_x, video_x, *args, **kwargs): if(self.need_model_aligned): diff --git a/src/MMSA/models/multiTask/TETFN.py b/src/MMSA/models/multiTask/TETFN.py new file mode 100644 index 0000000..2f79534 --- /dev/null +++ b/src/MMSA/models/multiTask/TETFN.py @@ -0,0 +1,259 @@ +""" +Paper: TETFN: 
A text enhanced transformer fusion network for multimodal sentiment analysis
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..subNets.transformers_encoder.transformer import TransformerEncoder
+from ..subNets import BertTextEncoder
+
+__all__ = ['TETFN']
+
+class TETFN(nn.Module):
+    def __init__(self, args):
+        super(TETFN, self).__init__()
+        # text subnets
+        self.args = args
+        self.aligned = args.need_data_aligned
+        self.text_model = BertTextEncoder(use_finetune=args.use_finetune, transformers=args.transformers, pretrained=args.pretrained)
+
+        # audio-vision subnets
+        text_in, audio_in, video_in = args.feature_dims
+
+        self.audio_model = AuViSubNet(
+            audio_in,
+            args.a_lstm_hidden_size,
+            args.conv1d_kernel_size_a,
+            args.dst_feature_dims,
+            num_layers=args.a_lstm_layers, dropout=args.a_lstm_dropout)
+
+        self.video_model = AuViSubNet(
+            video_in,
+            args.v_lstm_hidden_size,
+            args.conv1d_kernel_size_v,
+            args.dst_feature_dims,
+            num_layers=args.v_lstm_layers, dropout=args.v_lstm_dropout)
+
+        self.proj_l = nn.Conv1d(text_in, args.dst_feature_dims, kernel_size=args.conv1d_kernel_size_l, padding=0, bias=False)
+        # fusion subnets
+        self.trans_l_with_a = self.get_network(self_type='la')
+        self.trans_l_with_v = self.get_network(self_type='lv')
+        self.trans_a_with_l = self.get_network(self_type='al')
+
+        self.trans_a_with_v = TextEnhancedTransformer(
+            embed_dim=args.dst_feature_dims,
+            num_heads=args.nheads,
+            layers=2, attn_dropout=args.attn_dropout, relu_dropout=args.relu_dropout, res_dropout=args.res_dropout, embed_dropout=args.embed_dropout)
+
+        self.trans_v_with_l = self.get_network(self_type='vl')
+
+        self.trans_v_with_a = TextEnhancedTransformer(
+            embed_dim=args.dst_feature_dims,
+            num_heads=args.nheads,
+            layers=2, attn_dropout=args.attn_dropout, relu_dropout=args.relu_dropout, res_dropout=args.res_dropout, embed_dropout=args.embed_dropout)
+
+        self.trans_l_mem = self.get_network(self_type='l_mem', layers=2)
+        self.trans_a_mem = self.get_network(self_type='a_mem', layers=2)
+        self.trans_v_mem = self.get_network(self_type='v_mem', layers=2)
+
+        # the post_fusion layers
+        self.post_fusion_dropout = nn.Dropout(p=args.post_fusion_dropout)
+        self.post_fusion_layer_1 = nn.Linear(6 * args.dst_feature_dims, args.post_fusion_dim)
+        self.post_fusion_layer_2 = nn.Linear(args.post_fusion_dim, args.post_fusion_dim)
+        self.post_fusion_layer_3 = nn.Linear(args.post_fusion_dim, 1)
+
+        # the classification layer for text
+        self.post_text_dropout = nn.Dropout(p=args.post_text_dropout)
+        self.post_text_layer_1 = nn.Linear(args.dst_feature_dims, args.post_text_dim)
+        self.post_text_layer_2 = nn.Linear(args.post_text_dim, args.post_text_dim)
+        self.post_text_layer_3 = nn.Linear(args.post_text_dim, 1)
+
+        # the classification layer for audio
+        self.post_audio_dropout = nn.Dropout(p=args.post_audio_dropout)
+        self.post_audio_layer_1 = nn.Linear(args.dst_feature_dims, args.post_audio_dim)
+        self.post_audio_layer_2 = nn.Linear(args.post_audio_dim, args.post_audio_dim)
+        self.post_audio_layer_3 = nn.Linear(args.post_audio_dim, 1)
+
+        # the classification layer for video
+        self.post_video_dropout = nn.Dropout(p=args.post_video_dropout)
+        self.post_video_layer_1 = nn.Linear(args.dst_feature_dims, args.post_video_dim)
+        self.post_video_layer_2 = nn.Linear(args.post_video_dim, args.post_video_dim)
+        self.post_video_layer_3 = nn.Linear(args.post_video_dim, 1)
+
+    def get_network(self, self_type='l', layers=-1):
+        if self_type in ['l', 'al', 'vl']:
+            embed_dim, attn_dropout = self.args.dst_feature_dims,
self.args.attn_dropout + elif self_type in ['a', 'la', 'va']: + embed_dim, attn_dropout = self.args.dst_feature_dims, self.args.attn_dropout_a + elif self_type in ['v', 'lv', 'av']: + embed_dim, attn_dropout = self.args.dst_feature_dims, self.args.attn_dropout_v + elif self_type == 'l_mem': + embed_dim, attn_dropout = 2*self.args.dst_feature_dims, self.args.attn_dropout + elif self_type == 'a_mem': + embed_dim, attn_dropout = 2*self.args.dst_feature_dims, self.args.attn_dropout + elif self_type == 'v_mem': + embed_dim, attn_dropout = 2*self.args.dst_feature_dims, self.args.attn_dropout + else: + raise ValueError("Unknown network type") + + return TransformerEncoder(embed_dim=embed_dim, + num_heads=self.args.nheads, + layers=2, + attn_dropout=attn_dropout, + relu_dropout=self.args.relu_dropout, + res_dropout=self.args.res_dropout, + embed_dropout=self.args.embed_dropout, + attn_mask=True) + + def forward(self, text, audio, video): + audio, audio_lengths = audio + video, video_lengths = video + + mask_len = torch.sum(text[:,1,:], dim=1, keepdim=True) + text_lengths = mask_len.squeeze(1).int().detach().cpu() + + text = self.text_model(text) + + if self.aligned: + audio = self.audio_model(audio, text_lengths) + video = self.video_model(video, text_lengths) + else: + audio = self.audio_model(audio, audio_lengths) + video = self.video_model(video, video_lengths) + + text = self.proj_l(text.transpose(1,2)) + proj_x_a = audio.permute(2, 0, 1) + proj_x_v = video.permute(2, 0, 1) + proj_x_l = text.permute(2, 0, 1) + + text_h = torch.max(proj_x_l, dim=0)[0] + audio_h = torch.max(proj_x_a, dim=0)[0] + video_h = torch.max(proj_x_v, dim=0)[0] + + # (V,A) --> L + h_l_with_as = self.trans_l_with_a(proj_x_l, proj_x_a, proj_x_a) # Dimension (L, N, d_l) + h_l_with_vs = self.trans_l_with_v(proj_x_l, proj_x_v, proj_x_v) # Dimension (L, N, d_l) + h_ls = torch.cat([h_l_with_as, h_l_with_vs], dim=2) + h_ls = self.trans_l_mem(h_ls) + if type(h_ls) == tuple: + h_ls = h_ls[0] + last_h_l = h_ls[-1] # Take the last output for prediction + + # (L,V) --> A + h_a_with_ls = self.trans_a_with_l(proj_x_a, proj_x_l, proj_x_l) + h_a_with_vs = self.trans_a_with_v(proj_x_v, proj_x_a, proj_x_l) + h_as = torch.cat([h_a_with_ls, h_a_with_vs], dim=2) + h_as = self.trans_a_mem(h_as) + if type(h_as) == tuple: + h_as = h_as[0] + last_h_a = h_as[-1] + + # (L,A) --> V + h_v_with_ls = self.trans_v_with_l(proj_x_v, proj_x_l, proj_x_l) + h_v_with_as = self.trans_v_with_a(proj_x_a, proj_x_v, proj_x_l) + h_vs = torch.cat([h_v_with_ls, h_v_with_as], dim=2) + h_vs = self.trans_v_mem(h_vs) + if type(h_vs) == tuple: + h_vs = h_vs[0] + last_h_v = h_vs[-1] + + # fusion + fusion_h = torch.cat([last_h_l, last_h_a, last_h_v], dim=-1) + fusion_h = self.post_fusion_dropout(fusion_h) + fusion_h = F.relu(self.post_fusion_layer_1(fusion_h), inplace=False) + # # text + text_h = self.post_text_dropout(text_h) + text_h = F.relu(self.post_text_layer_1(text_h), inplace=False) + # audio + audio_h = self.post_audio_dropout(audio_h) + audio_h = F.relu(self.post_audio_layer_1(audio_h), inplace=False) + # vision + video_h = self.post_video_dropout(video_h) + video_h = F.relu(self.post_video_layer_1(video_h), inplace=False) + + # classifier-fusion + x_f = F.relu(self.post_fusion_layer_2(fusion_h), inplace=False) + output_fusion = self.post_fusion_layer_3(x_f) + + # classifier-text + x_t = F.relu(self.post_text_layer_2(text_h), inplace=False) + output_text = self.post_text_layer_3(x_t) + + # classifier-audio + x_a = F.relu(self.post_audio_layer_2(audio_h), 
inplace=False) + output_audio = self.post_audio_layer_3(x_a) + + # classifier-vision + x_v = F.relu(self.post_video_layer_2(video_h), inplace=False) + output_video = self.post_video_layer_3(x_v) + + res = { + 'M': output_fusion, + 'T': output_text, + 'A': output_audio, + 'V': output_video, + 'Feature_t': text_h, + 'Feature_a': audio_h, + 'Feature_v': video_h, + 'Feature_f': fusion_h, + } + return res + +class TextEnhancedTransformer(nn.Module): + def __init__(self, embed_dim, num_heads, layers, attn_dropout, relu_dropout, res_dropout, embed_dropout) -> None: + super().__init__() + + self.lower_mha = TransformerEncoder( + embed_dim=embed_dim, + num_heads=num_heads, + layers=1, + attn_dropout=attn_dropout, + relu_dropout=relu_dropout, + res_dropout=res_dropout, + embed_dropout=embed_dropout, + position_embedding=True, + attn_mask=True + ) + + self.upper_mha = TransformerEncoder( + embed_dim=embed_dim, + num_heads=num_heads, + layers=layers, + attn_dropout=attn_dropout, + relu_dropout=relu_dropout, + res_dropout=res_dropout, + embed_dropout=embed_dropout, + position_embedding=True, + attn_mask=True + ) + + def forward(self, query_m, key_m, text): + c = self.lower_mha(query_m, text, text) + return self.upper_mha(key_m, c, c) + +class AuViSubNet(nn.Module): + def __init__(self, in_size, hidden_size, conv1d_kernel_size, dst_feature_dims, num_layers=1, dropout=0.2, bidirectional=False): + ''' + Args: + in_size: input dimension + hidden_size: hidden layer dimension + num_layers: specify the number of layers of LSTMs. + dropout: dropout probability + bidirectional: specify usage of bidirectional LSTM + Output: + (return value in forward) a tensor of shape (batch_size, hidden_size) + ''' + super(AuViSubNet, self).__init__() + self.rnn = nn.LSTM(in_size, hidden_size, num_layers=num_layers, dropout=dropout, bidirectional=bidirectional, batch_first=True) + + self.conv = nn.Conv1d(hidden_size, dst_feature_dims, kernel_size=conv1d_kernel_size, bias=False) + + + def forward(self, x, lengths): + ''' + x: (batch_size, sequence_len, in_size) + ''' + h, _ = self.rnn(x) + h = self.conv(h.transpose(1,2)) + return h diff --git a/src/MMSA/models/multiTask/__init__.py b/src/MMSA/models/multiTask/__init__.py index 77f739f..ef6604b 100644 --- a/src/MMSA/models/multiTask/__init__.py +++ b/src/MMSA/models/multiTask/__init__.py @@ -2,3 +2,4 @@ from .MLMF import MLMF from .MTFN import MTFN from .SELF_MM import SELF_MM +from .TETFN import TETFN \ No newline at end of file diff --git a/src/MMSA/models/singleTask/CENET.py b/src/MMSA/models/singleTask/CENET.py new file mode 100644 index 0000000..843823d --- /dev/null +++ b/src/MMSA/models/singleTask/CENET.py @@ -0,0 +1,487 @@ +""" +Paper: Cross-modal Enhancement Network for Multimodal Sentiment Analysis +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import math +import sys +from torch.nn import CrossEntropyLoss, MSELoss +from pytorch_transformers import BertConfig +from pytorch_transformers.modeling_utils import PreTrainedModel, prune_linear_layer +from pytorch_transformers import BERT_PRETRAINED_MODEL_ARCHIVE_MAP +from torch.nn import LayerNorm as BertLayerNorm + +def gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) +def swish(x): + return x * torch.sigmoid(x) + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} +class CE(nn.Module): + def __init__(self, config, args): + super(CE, self).__init__() + TEXT_DIM = args.feature_dims[0] + AUDIO_DIM = args.feature_dims[1] + VIS_DIM = 
args.feature_dims[2] + self.visual_transform = nn.Sequential( + nn.Linear(VIS_DIM, config.hidden_size), + nn.ReLU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + self.acoustic_transform = nn.Sequential( + nn.Linear(AUDIO_DIM, config.hidden_size), + nn.ReLU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + self.hv = SelfAttention(TEXT_DIM) + self.ha = SelfAttention(TEXT_DIM) + self.cat_connect = nn.Linear(2 * TEXT_DIM, TEXT_DIM) + + def forward(self, text_embedding, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None): + visual_ = self.visual_transform(visual) + acoustic_ = self.acoustic_transform(acoustic) + visual_ = self.hv(text_embedding, visual_) + acoustic_ = self.ha(text_embedding, acoustic_) + visual_acoustic = torch.cat((visual_, acoustic_), dim=-1) + shift = self.cat_connect(visual_acoustic) + embedding_shift = shift + text_embedding + return embedding_shift + +class Attention(nn.Module): + def __init__(self, text_dim): + super(Attention, self).__init__() + self.text_dim = text_dim + self.dim = text_dim + self.Wq = nn.Linear(text_dim, text_dim) + self.Wk = nn.Linear(self.dim, text_dim) + self.Wv = nn.Linear(self.dim, text_dim) + + def forward(self, text_embedding, embedding): + Q = self.Wq(text_embedding) + K = self.Wk(embedding) + V = self.Wv(embedding) + tmp = torch.matmul(Q, K.transpose(-1, -2) * math.sqrt(self.text_dim))[0] + weight_matrix = F.softmax(torch.matmul(Q, K.transpose(-1, -2) * math.sqrt(self.text_dim)), dim=-1) + + return torch.matmul(weight_matrix, V) + +class SelfAttention(nn.Module): + def __init__(self, hidden_size, head_num=1): + super(SelfAttention, self).__init__() + self.head_num = head_num + self.s_d = hidden_size // self.head_num + self.all_head_size = self.head_num * self.s_d + self.Wq = nn.Linear(hidden_size, hidden_size) + self.Wk = nn.Linear(hidden_size, hidden_size) + self.Wv = nn.Linear(hidden_size, hidden_size) + + def transpose_for_scores(self, x): + x = x.view(x.size(0), x.size(1), self.head_num, -1) + return x.permute(0, 2, 1, 3) + + def forward(self, text_embedding, embedding): + Q = self.Wq(text_embedding) + K = self.Wk(embedding) + V = self.Wv(embedding) + Q = self.transpose_for_scores(Q) + K = self.transpose_for_scores(K) + V = self.transpose_for_scores(V) + weight_score = torch.matmul(Q, K.transpose(-1, -2)) + weight_prob = nn.Softmax(dim=-1)(weight_score * 8) + + context_layer = torch.matmul(weight_prob, V) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + 
else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask, head_mask=None): + attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + return outputs + +class BertEncoder(nn.Module): + def __init__(self, config, args = None): + super(BertEncoder, self).__init__() + self.config = config + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.injection_index = 1 + self.CE = CE(config,args) + + def forward(self, hidden_states, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, attention_mask=None, head_mask=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if i == self.injection_index: + hidden_states = self.CE(hidden_states, visual=visual, acoustic=acoustic, visual_ids=visual_ids, acoustic_ids=acoustic_ids) + + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position, token_type, POS, word-level and sentence-level sentiment embeddings. 
+ """ + def __init__(self, config, pos_tag_embedding = False, senti_embedding = False, polarity_embedding = False): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + if senti_embedding: + self.senti_embeddings = nn.Embedding(3, config.hidden_size, padding_idx=2) + else: + self.register_parameter('senti_embeddings', None) + if pos_tag_embedding: + self.pos_tag_embeddings = nn.Embedding(5, config.hidden_size, padding_idx=4) + else: + self.register_parameter('pos_tag_embeddings', None) + if polarity_embedding: + self.polarity_embeddings = nn.Embedding(6, config.hidden_size, padding_idx=5) + else: + self.register_parameter('polarity_embeddings', None) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, position_ids=None, pos_tag_ids=None, senti_word_ids=None, polarity_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + if senti_word_ids is not None and self.senti_embeddings is not None: + senti_word_embeddings = self.senti_embeddings(senti_word_ids) + else: + senti_word_embeddings = 0 + + if pos_tag_ids is not None and self.pos_tag_embeddings is not None: + pos_tag_embeddings = self.pos_tag_embeddings(pos_tag_ids) + else: + pos_tag_embeddings = 0 + + if polarity_ids is not None and self.polarity_embeddings is not None: + polarity_embeddings = self.polarity_embeddings(polarity_ids) + else: + polarity_embeddings = 0 + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = words_embeddings + position_embeddings + token_type_embeddings + senti_word_embeddings + pos_tag_embeddings + polarity_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None): + mixed_query_layer = 
self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + return outputs + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, input_tensor, attention_mask, head_mask=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class BertPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + config_class = BertConfig + pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "bert" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +class BertModel(BertPreTrainedModel): + + def __init__(self,config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False,args=None): + super(BertModel, self).__init__(config) + + self.embeddings = BertEmbeddings(config, pos_tag_embedding=pos_tag_embedding, senti_embedding=senti_embedding, polarity_embedding=polarity_embedding) + self.encoder = BertEncoder(config,args) + self.pooler = BertPooler(config) + + self.init_weights() + + def _resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.embeddings.word_embeddings + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + self.embeddings.word_embeddings = new_embeddings + return self.embeddings.word_embeddings + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward(self, input_ids, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, pos_ids=None, senti_word_ids=None, polarity_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) + else: + head_mask = [None] * self.config.num_hidden_layers + with torch.no_grad(): + embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, pos_tag_ids=pos_ids, + senti_word_ids=senti_word_ids, polarity_ids=polarity_ids) + encoder_outputs = self.encoder(embedding_output, visual, acoustic, visual_ids, acoustic_ids, + attention_mask=extended_attention_mask, + head_mask=head_mask) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] + return outputs + +class BertClassificationHead(nn.Module): + """Head for sentence-level 
classification tasks.""" + + def __init__(self, config): + super(BertClassificationHead, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, visual=None, visual_ids=None, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + +class CENET(BertPreTrainedModel): + config_class = BertConfig + pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "bert" + def __init__(self,config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False,args= None): + super(CENET, self).__init__(config) + self.num_labels = config.num_labels + self.bert = BertModel(config, pos_tag_embedding=pos_tag_embedding, + senti_embedding=senti_embedding, + polarity_embedding=polarity_embedding,args=args) + self.classifier = BertClassificationHead(config) + self.init_weights() + + def forward(self,text, acoustic, visual, visual_ids=None, acoustic_ids=None, pos_tag_ids=None, senti_word_ids=None, polarity_ids= None, position_ids=None, head_mask= None, labels=None): + input_ids = text[:,0,:].long() + attention_mask =text[:,1,:].long() + token_type_ids = text[:,2,:].long() + outputs = self.bert(input_ids, + visual=visual, + acoustic=acoustic, + visual_ids=visual_ids, + acoustic_ids=acoustic_ids, + pos_ids=pos_tag_ids, + senti_word_ids=senti_word_ids, + polarity_ids=polarity_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + if labels is not None: + if self.num_labels == 1: + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + outputs = (logits,) + outputs[2:] + return outputs + diff --git a/src/MMSA/models/singleTask/__init__.py b/src/MMSA/models/singleTask/__init__.py index 68c25eb..8783892 100644 --- a/src/MMSA/models/singleTask/__init__.py +++ b/src/MMSA/models/singleTask/__init__.py @@ -9,4 +9,5 @@ from .MISA import MISA from .MFM import MFM from .MMIM import MMIM -from .MCTN import MCTN \ No newline at end of file +from .MCTN import MCTN +from .CENET import CENET \ No newline at end of file diff --git a/src/MMSA/models/subNets/AlignNets.py b/src/MMSA/models/subNets/AlignNets.py index 9620177..1eed605 100644 --- a/src/MMSA/models/subNets/AlignNets.py +++ b/src/MMSA/models/subNets/AlignNets.py @@ -85,10 +85,10 @@ def align(x): pad_len = self.dst_len - raw_seq_len % self.dst_len pool_size = raw_seq_len // self.dst_len + 1 pad_x = x[:, -1, :].unsqueeze(1).expand([x.size(0), pad_len, x.size(-1)]) - x = torch.cat([x, pad_x], dim=1).view(x.size(0), pool_size, self.dst_len, -1) - x = x.mean(dim=1) + x = torch.cat([x, pad_x], dim=1).view(x.size(0), self.dst_len, pool_size, -1) + x = x.mean(dim=2) return x - text_x = align(text_x) + # text_x = align(text_x) audio_x = align(audio_x) video_x = align(video_x) return text_x, audio_x, video_x @@ -101,6 +101,20 @@ def __conv1d(self, text_x, audio_x, video_x): def forward(self, text_x, audio_x, video_x): # already aligned - if text_x.size(1) == audio_x.size(1) == video_x.size(1): + + # The input audio and video are of the 
same type + if isinstance(audio_x, tuple): + audio = audio_x[0] + video = video_x[0] + else: + audio = audio_x + video = video_x + + if self.dst_len == audio.size(1) == video.size(1): return text_x, audio_x, video_x - return self.ALIGN_WAY[self.mode](text_x, audio_x, video_x) \ No newline at end of file + result_tmp = self.ALIGN_WAY[self.mode](text_x, audio, video) + + if isinstance(audio_x, tuple): + return (result_tmp[0],(result_tmp[1],audio_x[1]),(result_tmp[2],video_x[1])) + else: + return result_tmp \ No newline at end of file diff --git a/src/MMSA/trains/ATIO.py b/src/MMSA/trains/ATIO.py index b96c2dc..a4287b8 100644 --- a/src/MMSA/trains/ATIO.py +++ b/src/MMSA/trains/ATIO.py @@ -23,11 +23,13 @@ def __init__(self): 'misa': MISA, 'mfm': MFM, 'mmim': MMIM, + 'cenet':CENET, # multi-task 'mtfn': MTFN, 'mlmf': MLMF, 'mlf_dnn': MLF_DNN, 'self_mm': SELF_MM, + 'tetfn': TETFN, # missing-task 'tfr_net': TFR_NET, } diff --git a/src/MMSA/trains/multiTask/MLF_DNN.py b/src/MMSA/trains/multiTask/MLF_DNN.py index e3134cf..86308f0 100644 --- a/src/MMSA/trains/multiTask/MLF_DNN.py +++ b/src/MMSA/trains/multiTask/MLF_DNN.py @@ -77,13 +77,11 @@ def do_train(self, model, dataloader, return_epoch_results=False): y_true[m].append(labels['M'].cpu()) train_loss = train_loss / len(dataloader['train']) - logger.info( - f"TRAIN-({self.args.model_name}) [{epochs - best_epoch}/{epochs}/{self.args.cur_seed}] >> loss: {round(train_loss, 4)} {dict_to_str(train_results)}" - ) for m in self.args.tasks: pred, true = torch.cat(y_pred[m]), torch.cat(y_true[m]) train_results = self.metrics(pred, true) logger.info('%s: >> ' %(m) + dict_to_str(train_results)) + # validation val_results = self.do_test(model, dataloader['valid'], mode="VAL") cur_valid = val_results[self.args.KeyEval] diff --git a/src/MMSA/trains/multiTask/MLMF.py b/src/MMSA/trains/multiTask/MLMF.py index e2f69ee..8f9c137 100644 --- a/src/MMSA/trains/multiTask/MLMF.py +++ b/src/MMSA/trains/multiTask/MLMF.py @@ -79,9 +79,6 @@ def do_train(self, model, dataloader, return_epoch_results=False): y_true[m].append(labels['M'].cpu()) train_loss = train_loss / len(dataloader['train']) - logger.info( - f"TRAIN-({self.args.model_name}) [{epochs - best_epoch}/{epochs}/{self.args.cur_seed}] >> loss: {round(train_loss, 4)} {dict_to_str(train_results)}" - ) for m in self.args.tasks: pred, true = torch.cat(y_pred[m]), torch.cat(y_true[m]) train_results = self.metrics(pred, true) diff --git a/src/MMSA/trains/multiTask/MTFN.py b/src/MMSA/trains/multiTask/MTFN.py index 0f45231..51c70d0 100644 --- a/src/MMSA/trains/multiTask/MTFN.py +++ b/src/MMSA/trains/multiTask/MTFN.py @@ -78,9 +78,6 @@ def do_train(self, model, dataloader, return_epoch_results=False): y_true[m].append(labels['M'].cpu()) train_loss = train_loss / len(dataloader['train']) - logger.info( - f"TRAIN-({self.args.model_name}) [{epochs - best_epoch}/{epochs}/{self.args.cur_seed}] >> loss: {round(train_loss, 4)} {dict_to_str(train_results)}" - ) for m in self.args.tasks: pred, true = torch.cat(y_pred[m]), torch.cat(y_true[m]) train_results = self.metrics(pred, true) diff --git a/src/MMSA/trains/multiTask/SELF_MM.py b/src/MMSA/trains/multiTask/SELF_MM.py index fcdbd6a..2a7d008 100644 --- a/src/MMSA/trains/multiTask/SELF_MM.py +++ b/src/MMSA/trains/multiTask/SELF_MM.py @@ -172,9 +172,7 @@ def do_train(self, model, dataloader, return_epoch_results=False): # update optimizer.step() train_loss = train_loss / len(dataloader['train']) - # logger.info( - # f"TRAIN-({self.args.model_name}) [{epochs - 
best_epoch}/{epochs}/{self.args.cur_seed}] >> loss: {round(train_loss, 4)} {dict_to_str(train_results)}" - # ) + for m in self.args.tasks: pred, true = torch.cat(y_pred[m]), torch.cat(y_true[m]) train_results = self.metrics(pred, true) diff --git a/src/MMSA/trains/multiTask/TETFN.py b/src/MMSA/trains/multiTask/TETFN.py new file mode 100644 index 0000000..51086c1 --- /dev/null +++ b/src/MMSA/trains/multiTask/TETFN.py @@ -0,0 +1,332 @@ +import logging +import os +import pickle as plk +import numpy as np +import torch +from torch import optim +from tqdm import tqdm +from ...utils import MetricsTop, dict_to_str + +logger = logging.getLogger('MMSA') + +class TETFN(): + def __init__(self, args): + assert args.train_mode == 'regression' + + self.args = args + self.args.tasks = "MTAV" + self.metrics = MetricsTop(args.train_mode).getMetics(args.dataset_name) + + self.feature_map = { + 'fusion': torch.zeros(args.train_samples, args.post_fusion_dim, requires_grad=False).to(args.device), + 'text': torch.zeros(args.train_samples, args.post_text_dim, requires_grad=False).to(args.device), + 'audio': torch.zeros(args.train_samples, args.post_audio_dim, requires_grad=False).to(args.device), + 'vision': torch.zeros(args.train_samples, args.post_video_dim, requires_grad=False).to(args.device), + } + + self.center_map = { + 'fusion': { + 'pos': torch.zeros(args.post_fusion_dim, requires_grad=False).to(args.device), + 'neg': torch.zeros(args.post_fusion_dim, requires_grad=False).to(args.device), + }, + 'text': { + 'pos': torch.zeros(args.post_text_dim, requires_grad=False).to(args.device), + 'neg': torch.zeros(args.post_text_dim, requires_grad=False).to(args.device), + }, + 'audio': { + 'pos': torch.zeros(args.post_audio_dim, requires_grad=False).to(args.device), + 'neg': torch.zeros(args.post_audio_dim, requires_grad=False).to(args.device), + }, + 'vision': { + 'pos': torch.zeros(args.post_video_dim, requires_grad=False).to(args.device), + 'neg': torch.zeros(args.post_video_dim, requires_grad=False).to(args.device), + } + } + + self.dim_map = { + 'fusion': torch.tensor(args.post_fusion_dim).float(), + 'text': torch.tensor(args.post_text_dim).float(), + 'audio': torch.tensor(args.post_audio_dim).float(), + 'vision': torch.tensor(args.post_video_dim).float(), + } + # new labels + self.label_map = { + 'fusion': torch.zeros(args.train_samples, requires_grad=False).to(args.device), + 'text': torch.zeros(args.train_samples, requires_grad=False).to(args.device), + 'audio': torch.zeros(args.train_samples, requires_grad=False).to(args.device), + 'vision': torch.zeros(args.train_samples, requires_grad=False).to(args.device) + } + + self.name_map = { + 'M': 'fusion', + 'T': 'text', + 'A': 'audio', + 'V': 'vision' + } + + def do_train(self, model, dataloader, return_epoch_results=False): + bert_no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + bert_params = list(model.Model.text_model.named_parameters()) + audio_params = list(model.Model.audio_model.named_parameters()) + video_params = list(model.Model.video_model.named_parameters()) + + bert_params_decay = [p for n, p in bert_params if not any(nd in n for nd in bert_no_decay)] + bert_params_no_decay = [p for n, p in bert_params if any(nd in n for nd in bert_no_decay)] + audio_params = [p for n, p in audio_params] + video_params = [p for n, p in video_params] + model_params_other = [p for n, p in list(model.Model.named_parameters()) if 'text_model' not in n and \ + 'audio_model' not in n and 'video_model' not in n] + + optimizer_grouped_parameters = [ + 
{'params': bert_params_decay, 'weight_decay': self.args.weight_decay_bert, 'lr': self.args.learning_rate_bert},
+            {'params': bert_params_no_decay, 'weight_decay': 0.0, 'lr': self.args.learning_rate_bert},
+            {'params': audio_params, 'weight_decay': self.args.weight_decay_audio, 'lr': self.args.learning_rate_audio},
+            {'params': video_params, 'weight_decay': self.args.weight_decay_video, 'lr': self.args.learning_rate_video},
+            {'params': model_params_other, 'weight_decay': self.args.weight_decay_other, 'lr': self.args.learning_rate_other}
+        ]
+        optimizer = optim.Adam(optimizer_grouped_parameters)
+
+        saved_labels = {}
+        # init labels
+        logger.info("Init labels")
+        with tqdm(dataloader['train']) as td:
+            for batch_data in td:
+                labels_m = batch_data['labels']['M'].view(-1).to(self.args.device)
+                indexes = batch_data['index'].view(-1)
+                self.init_labels(indexes, labels_m)
+
+        # initialize results
+        logger.info("Start training")
+        epochs, best_epoch = 0, 0
+        if return_epoch_results:
+            epoch_results = {
+                'train': [],
+                'valid': [],
+                'test': []
+            }
+        min_or_max = 'min' if self.args.KeyEval in ['Loss'] else 'max'
+        best_valid = 1e8 if min_or_max == 'min' else 0
+        # loop until early stop
+        while True:
+            epochs += 1
+            # train
+            y_pred = {'M': [], 'T': [], 'A': [], 'V': []}
+            y_true = {'M': [], 'T': [], 'A': [], 'V': []}
+            losses = []
+            model.train()
+            train_loss = 0.0
+            left_epochs = self.args.update_epochs
+            ids = []
+            with tqdm(dataloader['train']) as td:
+                for batch_data in td:
+                    if left_epochs == self.args.update_epochs:
+                        optimizer.zero_grad()
+                    left_epochs -= 1
+
+                    vision = batch_data['vision'].to(self.args.device)
+                    audio = batch_data['audio'].to(self.args.device)
+                    text = batch_data['text'].to(self.args.device)
+                    indexes = batch_data['index'].view(-1)
+                    cur_id = batch_data['id']
+                    ids.extend(cur_id)
+
+                    if not self.args.need_data_aligned:
+                        audio_lengths = batch_data['audio_lengths']
+                        vision_lengths = batch_data['vision_lengths']
+                    else:
+                        audio_lengths, vision_lengths = 0, 0
+
+                    # forward
+                    outputs = model(text, (audio, audio_lengths), (vision, vision_lengths))
+                    # store results
+                    for m in self.args.tasks:
+                        y_pred[m].append(outputs[m].cpu())
+                        y_true[m].append(self.label_map[self.name_map[m]][indexes].cpu())
+                    # compute loss
+                    loss = 0.0
+                    for m in self.args.tasks:
+                        loss += self.weighted_loss(outputs[m], self.label_map[self.name_map[m]][indexes], \
+                                                   indexes=indexes, mode=self.name_map[m])
+                    # backward
+                    loss.backward()
+                    train_loss += loss.item()
+                    # update features
+                    f_fusion = outputs['Feature_f'].detach()
+                    f_text = outputs['Feature_t'].detach()
+                    f_audio = outputs['Feature_a'].detach()
+                    f_vision = outputs['Feature_v'].detach()
+                    if epochs > 1:
+                        self.update_labels(f_fusion, f_text, f_audio, f_vision, epochs, indexes, outputs)
+
+                    self.update_features(f_fusion, f_text, f_audio, f_vision, indexes)
+                    self.update_centers()
+
+                    # update parameters
+                    if not left_epochs:
+                        # update
+                        optimizer.step()
+                        left_epochs = self.args.update_epochs
+                if not left_epochs:
+                    # update
+                    optimizer.step()
+            train_loss = train_loss / len(dataloader['train'])
+            # logger.info(
+            #     f"TRAIN-({self.args.model_name}) [{epochs - best_epoch}/{epochs}/{self.args.cur_seed}] >> loss: {round(train_loss, 4)} {dict_to_str(train_results)}"
+            # )
+            for m in self.args.tasks:
+                pred, true = torch.cat(y_pred[m]), torch.cat(y_true[m])
+                train_results = self.metrics(pred, true)
+                logger.info('%s: >> ' %(m) + dict_to_str(train_results))
+            # validation
+            val_results = self.do_test(model, dataloader['valid'], mode="VAL")
+
cur_valid = val_results[self.args.KeyEval] + # save best model + isBetter = cur_valid <= (best_valid - 1e-6) if min_or_max == 'min' else cur_valid >= (best_valid + 1e-6) + if isBetter: + best_valid, best_epoch = cur_valid, epochs + # save model + torch.save(model.cpu().state_dict(), self.args.model_save_path) + model.to(self.args.device) + # save labels + if self.args.save_labels: + tmp_save = {k: v.cpu().numpy() for k, v in self.label_map.items()} + tmp_save['ids'] = ids + saved_labels[epochs] = tmp_save + # epoch results + if return_epoch_results: + train_results["Loss"] = train_loss + epoch_results['train'].append(train_results) + epoch_results['valid'].append(val_results) + test_results = self.do_test(model, dataloader['test'], mode="TEST") + epoch_results['test'].append(test_results) + # early stop + if epochs - best_epoch >= self.args.early_stop: + if self.args.save_labels: + with open(os.path.join(self.args.res_save_dir, f'{self.args.model_name}-{self.args.dataset_name}-labels.pkl'), 'wb') as df: + plk.dump(saved_labels, df, protocol=4) + return epoch_results if return_epoch_results else None + + def do_test(self, model, dataloader, mode="VAL", return_sample_results=False): + model.eval() + y_pred = {'M': [], 'T': [], 'A': [], 'V': []} + y_true = {'M': [], 'T': [], 'A': [], 'V': []} + eval_loss = 0.0 + if return_sample_results: + ids, sample_results = [], [] + all_labels = [] + features = { + "Feature_t": [], + "Feature_a": [], + "Feature_v": [], + "Feature_f": [], + } + # criterion = nn.L1Loss() + with torch.no_grad(): + with tqdm(dataloader) as td: + for batch_data in td: + vision = batch_data['vision'].to(self.args.device) + audio = batch_data['audio'].to(self.args.device) + text = batch_data['text'].to(self.args.device) + if not self.args.need_data_aligned: + audio_lengths = batch_data['audio_lengths'] + vision_lengths = batch_data['vision_lengths'] + else: + audio_lengths, vision_lengths = 0, 0 + + labels_m = batch_data['labels']['M'].to(self.args.device).view(-1) + outputs = model(text, (audio, audio_lengths), (vision, vision_lengths)) + + if return_sample_results: + ids.extend(batch_data['id']) + for item in features.keys(): + features[item].append(outputs[item].cpu().detach().numpy()) + all_labels.extend(labels_m.cpu().detach().tolist()) + preds = outputs["M"].cpu().detach().numpy() + # test_preds_i = np.argmax(preds, axis=1) + sample_results.extend(preds.squeeze()) + + loss = self.weighted_loss(outputs['M'], labels_m) + eval_loss += loss.item() + y_pred['M'].append(outputs['M'].cpu()) + y_true['M'].append(labels_m.cpu()) + eval_loss = eval_loss / len(dataloader) + logger.info(mode+"-(%s)" % self.args.model_name + " >> loss: %.4f " % eval_loss) + pred, true = torch.cat(y_pred['M']), torch.cat(y_true['M']) + eval_results = self.metrics(pred, true) + logger.info('M: >> ' + dict_to_str(eval_results)) + eval_results['Loss'] = round(eval_loss, 4) + + if return_sample_results: + eval_results["Ids"] = ids + eval_results["SResults"] = sample_results + for k in features.keys(): + features[k] = np.concatenate(features[k], axis=0) + eval_results['Features'] = features + eval_results['Labels'] = all_labels + + return eval_results + + def weighted_loss(self, y_pred, y_true, indexes=None, mode='fusion'): + y_pred = y_pred.view(-1) + y_true = y_true.view(-1) + if mode == 'fusion': + weighted = torch.ones_like(y_pred) + else: + weighted = torch.tanh(torch.abs(self.label_map[mode][indexes] - self.label_map['fusion'][indexes])) + loss = torch.mean(weighted * torch.abs(y_pred - y_true)) + 
return loss + + def update_features(self, f_fusion, f_text, f_audio, f_vision, indexes): + self.feature_map['fusion'][indexes] = f_fusion + self.feature_map['text'][indexes] = f_text + self.feature_map['audio'][indexes] = f_audio + self.feature_map['vision'][indexes] = f_vision + + def update_centers(self): + def update_single_center(mode): + neg_indexes = self.label_map[mode] < 0 + if self.args.excludeZero: + pos_indexes = self.label_map[mode] > 0 + else: + pos_indexes = self.label_map[mode] >= 0 + self.center_map[mode]['pos'] = torch.mean(self.feature_map[mode][pos_indexes], dim=0) + self.center_map[mode]['neg'] = torch.mean(self.feature_map[mode][neg_indexes], dim=0) + + update_single_center(mode='fusion') + update_single_center(mode='text') + update_single_center(mode='audio') + update_single_center(mode='vision') + + def init_labels(self, indexes, m_labels): + self.label_map['fusion'][indexes] = m_labels + self.label_map['text'][indexes] = m_labels + self.label_map['audio'][indexes] = m_labels + self.label_map['vision'][indexes] = m_labels + + def update_labels(self, f_fusion, f_text, f_audio, f_vision, cur_epoches, indexes, outputs): + MIN = 1e-8 + def update_single_label(f_single, mode): + d_sp = torch.norm(f_single - self.center_map[mode]['pos'], dim=-1) + d_sn = torch.norm(f_single - self.center_map[mode]['neg'], dim=-1) + delta_s = (d_sn - d_sp) / (d_sp + MIN) + # d_s_pn = torch.norm(self.center_map[mode]['pos'] - self.center_map[mode]['neg'], dim=-1) + # delta_s = (d_sn - d_sp) / (d_s_pn + MIN) + alpha = delta_s / (delta_f + MIN) + + new_labels = 0.5 * alpha * self.label_map['fusion'][indexes] + \ + 0.5 * (self.label_map['fusion'][indexes] + delta_s - delta_f) + new_labels = torch.clamp(new_labels, min=-self.args.H, max=self.args.H) + # new_labels = torch.tanh(new_labels) * self.args.H + + n = cur_epoches + self.label_map[mode][indexes] = (n - 1) / (n + 1) * self.label_map[mode][indexes] + 2 / (n + 1) * new_labels + + d_fp = torch.norm(f_fusion - self.center_map['fusion']['pos'], dim=-1) + d_fn = torch.norm(f_fusion - self.center_map['fusion']['neg'], dim=-1) + # d_f_pn = torch.norm(self.center_map['fusion']['pos'] - self.center_map['fusion']['neg'], dim=-1) + # delta_f = (d_fn - d_fp) / (d_f_pn + MIN) + delta_f = (d_fn - d_fp) / (d_fp + MIN) + + update_single_label(f_text, mode='text') + update_single_label(f_audio, mode='audio') + update_single_label(f_vision, mode='vision') diff --git a/src/MMSA/trains/multiTask/__init__.py b/src/MMSA/trains/multiTask/__init__.py index 77f739f..b9adc38 100644 --- a/src/MMSA/trains/multiTask/__init__.py +++ b/src/MMSA/trains/multiTask/__init__.py @@ -2,3 +2,4 @@ from .MLMF import MLMF from .MTFN import MTFN from .SELF_MM import SELF_MM +from .TETFN import TETFN diff --git a/src/MMSA/trains/singleTask/CENET.py b/src/MMSA/trains/singleTask/CENET.py new file mode 100644 index 0000000..5dbd191 --- /dev/null +++ b/src/MMSA/trains/singleTask/CENET.py @@ -0,0 +1,125 @@ +import logging +from tqdm import tqdm +import torch +import torch.nn as nn +from ...utils import MetricsTop, dict_to_str +from transformers import BertTokenizer + +logger = logging.getLogger('MMSA') + +class CENET(): + def __init__(self, args): + self.args = args + self.args.max_grad_norm = 2 + self.metrics = MetricsTop(args.train_mode).getMetics(args.dataset_name) + self.tokenizer = BertTokenizer.from_pretrained(args.pretrained) + self.criterion = nn.L1Loss() + def do_train(self, model, dataloader,return_epoch_results=False): + param_optimizer = list(model.named_parameters()) + 
no_decay = ["bias", "LayerNorm.weight"] + CE_params = ['CE'] + if return_epoch_results: + epoch_results = {'train': [],'valid': [],'test': []} + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and not any(nd in n for nd in CE_params) + ], + "weight_decay": self.args.weight_decay, + }, + {"params": model.Model.bert.encoder.CE.parameters(), 'lr':self.args.learning_rate, "weight_decay": self.args.weight_decay}, + { + "params": [ + p for n, p in param_optimizer if any(nd in n for nd in no_decay) and not any(nd in n for nd in CE_params) + ], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) + epochs, best_epoch = 0, 0 + min_or_max = 'min' if self.args.KeyEval in ['Loss'] else 'max' + best_valid = 1e8 if min_or_max == 'min' else 0 + while True: + epochs += 1 + y_pred = [] + y_true = [] + model.train() + train_loss = 0.0 + with tqdm(dataloader['train']) as td: + for index,batch_data in enumerate(td): + loss = 0.0 + vision = batch_data['vision'].to(self.args.device) + audio = batch_data['audio'].to(self.args.device) + text = batch_data['text'].to(self.args.device) + labels = batch_data['labels']['M'] + labels = labels.to(self.args.device).view(-1, 1) + optimizer.zero_grad() + outputs = model(text,audio,vision) + logits = outputs[0] + loss += self.criterion(logits, labels) + loss.backward() + if self.args.max_grad_norm != -1.0: + torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) + optimizer.step() + logits = logits.detach().cpu() + labels = labels.detach().cpu() + train_loss += loss.item() + y_pred.append(logits) + y_true.append(labels) + train_loss = train_loss / len(dataloader['train']) + logger.info("TRAIN-(%s) (%d/%d)>> loss: %.4f " % (self.args.model_name, \ + epochs - best_epoch, epochs, train_loss)) + pred, true = torch.cat(y_pred), torch.cat(y_true) + train_results = self.metrics(pred, true) + logger.info('%s: >> ' %('Multimodal') + dict_to_str(train_results)) + # validation + val_results = self.do_test(model, dataloader['valid'], mode="VAL") + cur_valid = val_results[self.args.KeyEval] + isBetter = cur_valid <= (best_valid - 1e-6) if min_or_max == 'min' else cur_valid >= (best_valid + 1e-6) + # save best model + if isBetter: + best_valid, best_epoch = cur_valid, epochs + # save model + torch.save(model.cpu().state_dict(), self.args.model_save_path ) + model.to(self.args.device) + # early stop + if return_epoch_results: + train_results["Loss"] = train_loss + epoch_results['train'].append(train_results) + epoch_results['valid'].append(val_results) + test_results = self.do_test(model, dataloader['test'], mode="TEST") + epoch_results['test'].append(test_results) + # early stop + if epochs - best_epoch >= self.args.early_stop: + return epoch_results if return_epoch_results else None + + def do_test(self, model, dataloader, mode="VAL"): + model.eval() + y_pred = [] + y_true = [] + eval_loss = 0.0 + with torch.no_grad(): + with tqdm(dataloader) as td: + for batch_data in td: + loss = 0.0 + vision = batch_data['vision'].to(self.args.device) + audio = batch_data['audio'].to(self.args.device) + text = batch_data['text'].to(self.args.device) + labels = batch_data['labels']['M'] + labels = labels.to(self.args.device).view(-1, 1) + outputs = model(text,audio,vision) + logits = outputs[0] + loss += self.criterion(logits, labels) + eval_loss += loss.item() + logits = logits.detach().cpu() + labels = 
labels.detach().cpu() + y_pred.append(logits) + y_true.append(labels) + eval_loss = round(eval_loss / len(dataloader), 4) + logger.info(mode+"-(%s)" % self.args.model_name + " >> loss: %.4f " % eval_loss) + pred, true = torch.cat(y_pred), torch.cat(y_true) + results = self.metrics(pred, true) + logger.info('%s: >> ' %('Multimodal') + dict_to_str(results)) + eval_results = results + eval_results['Loss'] = eval_loss + return eval_results \ No newline at end of file diff --git a/src/MMSA/trains/singleTask/__init__.py b/src/MMSA/trains/singleTask/__init__.py index 68c25eb..8783892 100644 --- a/src/MMSA/trains/singleTask/__init__.py +++ b/src/MMSA/trains/singleTask/__init__.py @@ -9,4 +9,5 @@ from .MISA import MISA from .MFM import MFM from .MMIM import MMIM -from .MCTN import MCTN \ No newline at end of file +from .MCTN import MCTN +from .CENET import CENET \ No newline at end of file From de4c48fdde6cd0ea908ce525fa8b8257db0d95c6 Mon Sep 17 00:00:00 2001 From: cherishPre <2649458107@qq.com> Date: Wed, 20 Dec 2023 11:46:19 +0800 Subject: [PATCH 2/2] TETFN and CENET models added --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f90d180..2561550 100644 --- a/README.md +++ b/README.md @@ -145,8 +145,8 @@ MMSA uses feature files that are organized as follows: | Multi-Task | [MLF_DNN](src/MMSA/models/multiTask/MLF_DNN.py) | [MMSA](https://github.com/thuiar/MMSA) | ACL 2020 | | Multi-Task | [MTFN](src/MMSA/models/multiTask/MTFN.py) | [MMSA](https://github.com/thuiar/MMSA) | ACL 2020 | | Multi-Task | [MLMF](src/MMSA/models/multiTask/MLMF.py) | [MMSA](https://github.com/thuiar/MMSA) | ACL 2020 | -| Multi-Task | [SELF_MM](src/MMSA/models/multiTask/SELF_MM.py) | [Self-MM](https://github.com/thuiar/Self-MM) | AAAI 2021 -| Multi-Task | [TETFN](src/MMSA/models/multiTask/TETFN.py) | TETFN | PR 2023 +| Multi-Task | [SELF_MM](src/MMSA/models/multiTask/SELF_MM.py) | [Self-MM](https://github.com/thuiar/Self-MM) | AAAI 2021 | +| Multi-Task | [TETFN](src/MMSA/models/multiTask/TETFN.py) | TETFN | PR 2023 | | Single-Task | [BERT-MAG](src/MMSA/models/singleTask/BERT_MAG.py) | [MAG-BERT](https://github.com/WasifurRahman/BERT_multimodal_transformer) | ACL 2020 | | Single-Task | [MISA](src/MMSA/models/singleTask/MISA.py) | [MISA](https://github.com/declare-lab/MISA) | ACMMM 2020 | | | Single-Task | [MMIM](src/MMSA/models/singleTask/MMIM.py) | [MMIM](https://github.com/declare-lab/Multimodal-Infomax) | EMNLP 2021 |