diff --git a/HBI/models/modeling.py b/HBI/models/modeling.py
index be848e0..a98e0d0 100644
--- a/HBI/models/modeling.py
+++ b/HBI/models/modeling.py
@@ -225,6 +225,10 @@ def forward(self, text_ids, text_mask, video, video_mask=None, idx=None, global_
             banzhaf = self.banzhafmodel(logits.unsqueeze(1)).squeeze(1)
             with torch.no_grad():
                 teacher = self.banzhafteacher(logits.unsqueeze(1).clone().detach()).squeeze(1).detach()
+            teacher = torch.einsum('btv,bt->btv', [teacher, text_mask])
+            teacher = torch.einsum('btv,bv->btv', [teacher, video_mask])
+            banzhaf = torch.einsum('btv,bt->btv', [banzhaf, text_mask])
+            banzhaf = torch.einsum('btv,bv->btv', [banzhaf, video_mask])
             s_loss = self.kl(banzhaf, teacher) + self.kl(banzhaf.T, teacher.T)
             loss += M_loss + self.config.skl * s_loss
diff --git a/HBI/models/modeling_estimator.py b/HBI/models/modeling_estimator.py
new file mode 100644
index 0000000..2722098
--- /dev/null
+++ b/HBI/models/modeling_estimator.py
@@ -0,0 +1,378 @@
+import os
+from collections import OrderedDict
+from types import SimpleNamespace
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
+import torch.nn.functional as F
+from .module_clip import CLIP, convert_weights, _PT_NAME
+from .module_cross import CrossModel, Transformer as TransformerClip
+from .until_module import LayerNorm, AllGather, AllGather2, CrossEn, MSE, ArcCrossEn, KL
+import numpy as np
+from .banzhaf import BanzhafModule, BanzhafInteraction
+from .cluster import CTM, TCBlock
+
+allgather = AllGather.apply
+allgather2 = AllGather2.apply
+
+
+class ResidualLinear(nn.Module):
+    def __init__(self, d_int: int):
+        super(ResidualLinear, self).__init__()
+
+        self.fc_relu = nn.Sequential(nn.Linear(d_int, d_int),
+                                     nn.ReLU(inplace=True))
+
+    def forward(self, x):
+        x = x + self.fc_relu(x)
+        return x
+
+
+class HBI(nn.Module):
+    def __init__(self, config):
+        super(HBI, self).__init__()
+
+        self.config = config
+        self.interaction = config.interaction
+        self.agg_module = getattr(config, 'agg_module', 'meanP')
+        backbone = getattr(config, 'base_encoder', "ViT-B/32")
+
+        assert backbone in _PT_NAME
+        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[backbone])
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(model_path)
+        try:
+            # loading JIT archive
+            model = torch.jit.load(model_path, map_location="cpu").eval()
+            state_dict = model.state_dict()
+        except RuntimeError:
+            state_dict = torch.load(model_path, map_location="cpu")
+
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len(
+            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+        image_resolution = vision_patch_size * grid_size
+
+        embed_dim = state_dict["text_projection"].shape[1]
+        context_length = state_dict["positional_embedding"].shape[0]
+        vocab_size = state_dict["token_embedding.weight"].shape[0]
+        transformer_width = state_dict["ln_final.weight"].shape[0]
+        transformer_heads = transformer_width // 64
+        transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
+
+        self.clip = CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
+                         context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
+
+        if torch.cuda.is_available():
+            convert_weights(self.clip)  # fp16
+
+        cross_config = SimpleNamespace(**{
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 512,
+            "initializer_range": 0.02,
+            "intermediate_size": 2048,
+            "max_position_embeddings": 128,
+            "num_attention_heads": 8,
+            "num_hidden_layers": 4,
+            "vocab_size": 512,
+            "soft_t": 0.07,
+        })
+        cross_config.max_position_embeddings = context_length
+        cross_config.hidden_size = transformer_width
+        self.cross_config = cross_config
+        if self.interaction == 'wti':
+            self.text_weight_fc = nn.Sequential(
+                nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+                nn.Linear(2 * transformer_width, 1))
+            self.video_weight_fc = nn.Sequential(
+                nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+                nn.Linear(2 * transformer_width, 1))
+
+            self.text_weight_fc0 = nn.Sequential(
+                nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+                nn.Linear(2 * transformer_width, 1))
+            self.video_weight_fc0 = nn.Sequential(
+                nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+                nn.Linear(2 * transformer_width, 1))
+
+            self.text_weight_fc1 = nn.Sequential(
+                nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+                nn.Linear(2 * transformer_width, 1))
+            self.video_weight_fc1 = nn.Sequential(
+                nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+                nn.Linear(2 * transformer_width, 1))
+
+        if self.agg_module in ["seqLSTM", "seqTransf"]:
+            self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings,
+                                                          cross_config.hidden_size)
+            if self.agg_module == "seqTransf":
+                self.transformerClip = TransformerClip(width=transformer_width,
+                                                       layers=config.num_hidden_layers,
+                                                       heads=transformer_heads)
+            if self.agg_module == "seqLSTM":
+                self.lstm_visual = nn.LSTM(input_size=cross_config.hidden_size, hidden_size=cross_config.hidden_size,
+                                           batch_first=True, bidirectional=False, num_layers=1)
+
+        self.loss_fct = CrossEn(config)
+        self.loss_arcfct = ArcCrossEn(margin=10)
+        self.banzhafteacher = BanzhafModule(64)
+        self.banzhafinteraction = BanzhafInteraction(config.max_words, config.max_frames, 100)
+
+        self.apply(self.init_weights)  # random init must before loading pretrain
+        self.clip.load_state_dict(state_dict, strict=False)
+
+        self.mse = MSE()
+        self.kl = KL()
+
+        ## ===> Initialization trick [HARD CODE]
+        new_state_dict = OrderedDict()
+
+        if self.agg_module in ["seqLSTM", "seqTransf"]:
+            contain_frame_position = False
+            for key in state_dict.keys():
+                if key.find("frame_position_embeddings") > -1:
+                    contain_frame_position = True
+                    break
+            if contain_frame_position is False:
+                for key, val in state_dict.items():
+                    if key == "positional_embedding":
+                        new_state_dict["frame_position_embeddings.weight"] = val.clone()
+                        continue
+                    if self.agg_module in ["seqTransf"] and key.find("transformer.resblocks") == 0:
+                        num_layer = int(key.split(".")[2])
+                        # cut from beginning
+                        if num_layer < config.num_hidden_layers:
+                            new_state_dict[key.replace("transformer.", "transformerClip.")] = val.clone()
+                            continue
+
+        self.load_state_dict(new_state_dict, strict=False)  # only update new state (seqTransf/seqLSTM/tightTransf)
+        ## <=== End of initialization trick
+
+        for param in self.clip.parameters():
+            param.requires_grad = False  # not update by gradient
+        for param in self.transformerClip.parameters():
+            param.requires_grad = False  # not update by gradient
+        for param in self.frame_position_embeddings.parameters():
+            param.requires_grad = False  # not update by gradient
+
+        for param in self.text_weight_fc.parameters():
+            param.requires_grad = False  # not update by gradient
+        for param in self.video_weight_fc.parameters():
+            param.requires_grad = False  # not update by gradient
+
+
+    def forward(self, text_ids, text_mask, video, video_mask=None, idx=None, global_step=0):
+        text_ids = text_ids.view(-1, text_ids.shape[-1])
+        text_mask = text_mask.view(-1, text_mask.shape[-1])
+        video_mask = video_mask.view(-1, video_mask.shape[-1])
+        # B x N_v x 3 x H x W - > (B x N_v) x 3 x H x W
+        video = torch.as_tensor(video).float()
+        if len(video.size()) == 5:
+            b, n_v, d, h, w = video.shape
+            video = video.view(b * n_v, d, h, w)
+        else:
+            b, pair, bs, ts, channel, h, w = video.shape
+            video = video.view(b * pair * bs * ts, channel, h, w)
+
+        text_feat, video_feat, cls = self.get_text_video_feat(text_ids, text_mask, video, video_mask, shaped=True)
+
+        if self.training:
+            if torch.cuda.is_available():  # batch merge here
+                idx = allgather(idx, self.config)
+                text_feat = allgather(text_feat, self.config)
+                video_feat = allgather(video_feat, self.config)
+                text_mask = allgather(text_mask, self.config)
+                video_mask = allgather(video_mask, self.config)
+                cls = allgather(cls, self.config)
+                torch.distributed.barrier()  # force sync
+
+            idx = idx.view(-1, 1)
+            idx_all = idx.t()
+            pos_idx = torch.eq(idx, idx_all).float()
+            sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
+            logit_scale = self.clip.logit_scale.exp()
+            loss = 0.
+
+            # entity level
+            logits, text_weight, video_weight = self.entity_level(text_feat, cls, video_feat,
+                                                                  text_mask, video_mask)
+            logits = torch.diagonal(logits, dim1=0, dim2=1).permute(2, 0, 1)
+            y = self.banzhafinteraction(logits.clone().detach(), text_mask, video_mask, text_weight, video_weight).detach()
+            p = self.banzhafteacher(logits.unsqueeze(1)).squeeze(1)
+
+            p = torch.einsum('btv,bt->btv', [p, text_mask])
+            p = torch.einsum('btv,bv->btv', [p, video_mask])
+
+            loss += self.mse(y, p)
+
+            return loss
+        else:
+            return None
+
+    def entity_level(self, text_feat, cls, video_feat, text_mask, video_mask):
+        if self.config.interaction == 'wti':
+            text_weight = self.text_weight_fc(text_feat).squeeze(2)  # B x N_t x D -> B x N_t
+            text_weight.masked_fill_((1 - text_mask).to(torch.bool), float(-9e15))
+            text_weight = torch.softmax(text_weight, dim=-1)  # B x N_t
+
+            video_weight = self.video_weight_fc(video_feat).squeeze(2)  # B x N_v x D -> B x N_v
+            video_weight.masked_fill_((1 - video_mask).to(torch.bool), float(-9e15))
+            video_weight = torch.softmax(video_weight, dim=-1)  # B x N_v
+
+        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
+        video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True)
+
+        retrieve_logits = torch.einsum('atd,bvd->abtv', [text_feat, video_feat])
+        retrieve_logits = torch.einsum('abtv,at->abtv', [retrieve_logits, text_mask])
+        retrieve_logits = torch.einsum('abtv,bv->abtv', [retrieve_logits, video_mask])
+
+        text_sum = text_mask.sum(-1)
+        video_sum = video_mask.sum(-1)
+
+        if self.config.interaction == 'wti':  # weighted token-wise interaction
+            t2v_logits, max_idx1 = retrieve_logits.max(dim=-1)  # abtv -> abt
+            t2v_logits = torch.einsum('abt,at->ab', [t2v_logits, text_weight])
+
+            v2t_logits, max_idx2 = retrieve_logits.max(dim=-2)  # abtv -> abv
+            v2t_logits = torch.einsum('abv,bv->ab', [v2t_logits, video_weight])
+
+            _retrieve_logits = (t2v_logits + v2t_logits) / 2.0
+        else:
+            # max for video token
+            t2v_logits, max_idx1 = retrieve_logits.max(dim=-1)  # abtv -> abt
+            v2t_logits, max_idx2 = retrieve_logits.max(dim=-2)  # abtv -> abv
+            t2v_logits = torch.sum(t2v_logits, dim=2) / (text_sum.unsqueeze(1))
+            v2t_logits = torch.sum(v2t_logits, dim=2) / (video_sum.unsqueeze(0))
+            _retrieve_logits = (t2v_logits + v2t_logits) / 2.0
+
+        return retrieve_logits, text_weight, video_weight
+
+    def get_text_feat(self, text_ids, text_mask, shaped=False):
+        if shaped is False:
+            text_ids = text_ids.view(-1, text_ids.shape[-1])
+            text_mask = text_mask.view(-1, text_mask.shape[-1])
+
+        bs_pair = text_ids.size(0)
+        cls, text_feat = self.clip.encode_text(text_ids, return_hidden=True, mask=text_mask)
+        cls, text_feat = cls.float(), text_feat.float()
+        text_feat = text_feat.view(bs_pair, -1, text_feat.size(-1))
+        cls = cls.view(bs_pair, -1, cls.size(-1)).squeeze(1)
+        return text_feat, cls
+
+    def get_video_feat(self, video, video_mask, shaped=False):
+        if shaped is False:
+            video_mask = video_mask.view(-1, video_mask.shape[-1])
+            video = torch.as_tensor(video).float()
+            if len(video.size()) == 5:
+                b, n_v, d, h, w = video.shape
+                video = video.view(b * n_v, d, h, w)
+            else:
+                b, pair, bs, ts, channel, h, w = video.shape
+                video = video.view(b * pair * bs * ts, channel, h, w)
+
+        bs_pair, n_v = video_mask.size()
+        video_feat = self.clip.encode_image(video, return_hidden=True)[0].float()
+        video_feat = video_feat.float().view(bs_pair, -1, video_feat.size(-1))
+        video_feat = self.agg_video_feat(video_feat, video_mask, self.agg_module)
+        return video_feat
+
+    def get_text_video_feat(self, text_ids, text_mask, video, video_mask, shaped=False):
+        if shaped is False:
+            text_ids = text_ids.view(-1, text_ids.shape[-1])
+            text_mask = text_mask.view(-1, text_mask.shape[-1])
+            video_mask = video_mask.view(-1, video_mask.shape[-1])
+            video = torch.as_tensor(video).float()
+            if len(video.shape) == 5:
+                b, n_v, d, h, w = video.shape
+                video = video.view(b * n_v, d, h, w)
+            else:
+                b, pair, bs, ts, channel, h, w = video.shape
+                video = video.view(b * pair * bs * ts, channel, h, w)
+
+        text_feat, cls = self.get_text_feat(text_ids, text_mask, shaped=True)
+        video_feat = self.get_video_feat(video, video_mask, shaped=True)
+
+        return text_feat, video_feat, cls
+
+    def get_video_avg_feat(self, video_feat, video_mask):
+        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
+        video_feat = video_feat * video_mask_un
+        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
+        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
+        video_feat = torch.sum(video_feat, dim=1) / video_mask_un_sum
+        return video_feat
+
+    def get_text_sep_feat(self, text_feat, text_mask):
+        text_feat = text_feat.contiguous()
+        text_feat = text_feat[torch.arange(text_feat.shape[0]), torch.sum(text_mask, dim=-1) - 1, :]
+        text_feat = text_feat.unsqueeze(1).contiguous()
+        return text_feat
+
+    def agg_video_feat(self, video_feat, video_mask, agg_module):
+        video_feat = video_feat.contiguous()
+        if agg_module == "None":
+            pass
+        elif agg_module == "seqLSTM":
+            # Sequential type: LSTM
+            video_feat_original = video_feat
+            video_feat = pack_padded_sequence(video_feat, torch.sum(video_mask, dim=-1).cpu(),
+                                              batch_first=True, enforce_sorted=False)
+            video_feat, _ = self.lstm_visual(video_feat)
+            if self.training: self.lstm_visual.flatten_parameters()
+            video_feat, _ = pad_packed_sequence(video_feat, batch_first=True)
+            video_feat = torch.cat(
+                (video_feat, video_feat_original[:, video_feat.size(1):, ...].contiguous()), dim=1)
+            video_feat = video_feat + video_feat_original
+        elif agg_module == "seqTransf":
+            # Sequential type: Transformer Encoder
+            video_feat_original = video_feat
+            seq_length = video_feat.size(1)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=video_feat.device)
+            position_ids = position_ids.unsqueeze(0).expand(video_feat.size(0), -1)
+            frame_position_embeddings = self.frame_position_embeddings(position_ids)
+            video_feat = video_feat + frame_position_embeddings
+            extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0
+            extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1)
+            video_feat = video_feat.permute(1, 0, 2)  # NLD -> LND
+            video_feat = self.transformerClip(video_feat, extended_video_mask)
+            video_feat = video_feat.permute(1, 0, 2)  # LND -> NLD
+            video_feat = video_feat + video_feat_original
+        return video_feat
+
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
+    def init_weights(self, module):
+        """ Initialize the weights.
+        """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=0.02)
+        elif isinstance(module, LayerNorm):
+            if 'beta' in dir(module) and 'gamma' in dir(module):
+                module.beta.data.zero_()
+                module.gamma.data.fill_(1.0)
+            else:
+                module.bias.data.zero_()
+                module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
\ No newline at end of file
diff --git a/README.md b/README.md
index c36300f..f2c92a2 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ If you find this paper useful, please consider staring 🌟 this repo and citing
 ## 📣 Updates
+* Oct 11 2023: Release code for the Banzhaf Interaction estimator.
 * Oct 08 2023: I am working on the code for Banzhaf Interaction estimator, which is expected to be released soon.
 * Jun 28 2023: Release code for reimplementing the experiments in the paper.
 * Mar 28 2023: Our **HBI** has been selected as a Highlight paper at CVPR 2023! (Top 2.5% of 9155 submissions).
@@ -127,8 +128,40 @@ wget https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702
 
 #### Train the Banzhaf Interaction Estimator
 
-Train the estimator according to the label generated by the BanzhafInteraction in HBI/models/banzhaf.py.
-Training code is under preparation...
+Train the estimator using the labels generated by the `BanzhafInteraction` module in `HBI/models/banzhaf.py`. The training code is provided in `banzhaf_estimator.py`.
+
+Recommended training parameters and our pre-trained estimator weights will be released shortly.
+
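+For reference, the snippet below is a minimal sketch of the supervision used in `HBI/models/modeling_estimator.py`: the labels produced by `BanzhafInteraction` are regressed by the `BanzhafModule` estimator with a masked MSE. Shapes and hyper-parameters are copied from `modeling_estimator.py`; the dummy tensors, and the assumption that both modules run standalone on random inputs, are for illustration only. See `banzhaf_estimator.py` for the full training loop.
+
+```python
+import torch
+
+from HBI.models.banzhaf import BanzhafModule, BanzhafInteraction
+from HBI.models.until_module import MSE
+
+B, N_t, N_v = 4, 32, 12                        # batch size, text tokens, video frames (illustrative)
+logits = torch.randn(B, N_t, N_v)              # token-wise similarity map, as returned by entity_level()
+text_mask = torch.ones(B, N_t)                 # 1 = valid token, 0 = padding
+video_mask = torch.ones(B, N_v)                # 1 = valid frame, 0 = padding
+text_weight = torch.softmax(torch.randn(B, N_t), dim=-1)     # stand-ins for the learned token weights
+video_weight = torch.softmax(torch.randn(B, N_v), dim=-1)
+
+estimator = BanzhafModule(64)                  # the estimator being trained
+labeler = BanzhafInteraction(N_t, N_v, 100)    # generates Banzhaf Interaction labels (HBI/models/banzhaf.py)
+mse = MSE()
+
+y = labeler(logits, text_mask, video_mask, text_weight, video_weight).detach()  # regression targets
+p = estimator(logits.unsqueeze(1)).squeeze(1)                                   # estimator prediction
+p = torch.einsum('btv,bt->btv', [p, text_mask])                                 # zero out padded tokens
+p = torch.einsum('btv,bv->btv', [p, video_mask])                                # zero out padded frames
+loss = mse(y, p)
+loss.backward()
+```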