Route MoE Promote (Post/Pre) update 0110
root committed Jan 10, 2024
1 parent ce67e56 commit ffa3298
Showing 29 changed files with 6,761 additions and 195 deletions.
2 changes: 1 addition & 1 deletion environment.yml
@@ -1,4 +1,4 @@
name: minigptv
name: promptmoe
channels:
- pytorch
- defaults
Binary file added lizrun
Binary file not shown.
16 changes: 8 additions & 8 deletions minigpt4/configs/datasets/coco/caption.yaml
@@ -17,14 +17,14 @@ datasets:
# md5: aa31ac474cf6250ebb81d18348a07ed8
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_train.json
val:
url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
test:
url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json
# val:
# url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
# storage:
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
# test:
# url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
# storage:
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json

images:
storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO
2 changes: 2 additions & 0 deletions minigpt4/configs/datasets/coco/defaults_vqa_eval.yaml
@@ -20,6 +20,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val_eval.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_mscoco_val2014_annotations.json
@@ -29,6 +30,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_test.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json

images:
2 changes: 2 additions & 0 deletions minigpt4/configs/datasets/okvqa/eval.yaml
@@ -20,6 +20,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json
@@ -32,6 +33,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval_part100.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json
15 changes: 12 additions & 3 deletions minigpt4/datasets/datasets/caption_datasets.py
@@ -105,6 +105,8 @@ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
'Using language, provide a short account of the image.',
'Use a few words to illustrate what is happening in the picture.',
]
self.source = 'coco_cap'

def __getitem__(self, index):

# TODO this assumes image input, not general enough
@@ -118,13 +120,20 @@ def __getitem__(self, index):
image = self.vis_processor(image)
caption = self.text_processor(ann["caption"])

instruction = random.choice(self.instruction_pool)
instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
# instruction = random.choice(self.instruction_pool)
# instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
q_input = ""
llm_input = random.choice(self.instruction_pool)

return {
"image": image,
"image_id": ann["image"],
"answer": caption,
"instruction_input": instruction,
"q_input": q_input,
"llm_input": llm_input,
"text_input": llm_input,
"text_output": caption,
"source": 'coco_cap',
}

class CaptionEvalDataset(BaseDataset, __DisplMixin):
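For readers skimming the diff: the COCOCapDataset.__getitem__ change above replaces the single instruction_input field with a q_input/llm_input pair plus a source tag. A minimal sketch of the new item schema (illustrative only; image, caption and ann come from the dataset internals shown above):

    import random

    instruction_pool = [
        'Using language, provide a short account of the image.',
        'Use a few words to illustrate what is happening in the picture.',
    ]

    def build_caption_item(image, caption, ann):
        llm_input = random.choice(instruction_pool)   # prompt handed to the LLM side
        return {
            "image": image,
            "image_id": ann["image"],
            "answer": caption,
            "q_input": "",             # the Q-Former side receives no text prompt here
            "llm_input": llm_input,
            "text_input": llm_input,
            "text_output": caption,
            "source": "coco_cap",
        }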
1 change: 1 addition & 0 deletions minigpt4/datasets/datasets/coco_caption.py
@@ -31,6 +31,7 @@ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
split (string): val or test
"""
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
self.source = 'coco_cap'

def __getitem__(self, index):
ann = self.annotation[index]
1 change: 0 additions & 1 deletion minigpt4/datasets/datasets/dataloader_utils.py
@@ -31,7 +31,6 @@ def __init__(self, loaders, ratios=None):
if ratios is None:
ratios = [1.0] * len(loaders)
else:
# import pdb; pdb.set_trace()
assert len(ratios) == len(loaders)
ratios = [float(ratio) / sum(ratios) for ratio in ratios]

36 changes: 22 additions & 14 deletions minigpt4/models/QformerMoE.py
@@ -386,17 +386,23 @@ def forward(self, hidden_states, input_tensor):


class FeedForward(nn.Module):
# remove LayerNorm
def __init__(self, config):
nn.Module.__init__(self)
# first layer
self.intermediate_query = BertIntermediate(config)
# second layer
self.output_query = BertOutput(config)
super().__init__()
self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
# self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2
self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2

def forward(self, hidden_states: Tensor):
input_tensor = hidden_states
intermediate_output = self.intermediate_query(hidden_states)
hidden_states = self.output_query(intermediate_output, input_tensor)
hidden_states = self.dense1(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dense2(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
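Aside (not part of the commit): the rewritten FeedForward is now a plain, LayerNorm-free expert MLP; the residual connection and normalization are applied by the caller. A standalone sketch with assumed BERT-base sizes:

    import torch.nn as nn

    class ExpertFFN(nn.Module):
        # Linear -> activation -> Linear -> Dropout; no residual, no LayerNorm.
        # The enclosing MoE layer adds the residual and a shared LayerNorm (Post-LN).
        def __init__(self, hidden_size=768, intermediate_size=3072, dropout=0.2):
            super().__init__()
            self.dense1 = nn.Linear(hidden_size, intermediate_size)
            self.act = nn.GELU()                      # stands in for config.hidden_act
            self.dense2 = nn.Linear(intermediate_size, hidden_size)
            self.dropout = nn.Dropout(dropout)        # the commit bumps this from 0.1 to 0.2

        def forward(self, x):
            return self.dropout(self.dense2(self.act(self.dense1(x))))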


@@ -440,6 +446,7 @@ def __init__(self, config, layer_num):
)
else:
self.experts = ffn
self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(
self,
@@ -494,7 +501,8 @@ def forward(
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask) # layer_output, gate_loss, gate_load

# import pdb; pdb.set_trace() # test0107

if attention_output.shape[1] > query_length: # have text input in Qformer
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
@@ -503,6 +511,7 @@ def forward(
attention_output[:, query_length:, :],
)
layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])

else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
@@ -524,15 +533,14 @@ def feed_forward_chunk(self, attention_output):

def feed_forward_query_moe(self, attention_output, expert_attention_mask):
if not self.use_experts:
layer_output = self.experts(attention_output)
hidden_states = self.experts(attention_output)
layer_output = self.expert_ln(hidden_states + attention_output)
return layer_output, 0.0, []

# if not self.importance_processor.is_moe:
# raise RuntimeError("Need to turn the model to a MoE first.")

layer_output, gate_loss, gate_load = self.experts(
hidden_states, gate_loss, gate_load = self.experts(
attention_output, expert_attention_mask
)
layer_output = self.expert_ln(hidden_states + attention_output)
return layer_output, gate_loss, gate_load
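For orientation (a sketch under assumptions, not the committed code): feed_forward_query_moe now applies the residual connection and the new shared expert_ln once, after the expert mixture, instead of inside each expert:

    def moe_ffn_post_ln(experts, expert_ln, attention_output, expert_attention_mask):
        # `experts` is assumed to return (hidden_states, gate_loss, gate_load),
        # matching the MoELayer interface used above; expert_ln is nn.LayerNorm(hidden_size).
        hidden_states, gate_loss, gate_load = experts(attention_output, expert_attention_mask)
        layer_output = expert_ln(hidden_states + attention_output)   # residual + LayerNorm, post-MoE
        return layer_output, gate_loss, gate_load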

class BertEncoder(nn.Module):
50 changes: 30 additions & 20 deletions minigpt4/models/QformerRouteMoE.py
@@ -46,10 +46,9 @@
from transformers.models.bert.configuration_bert import BertConfig

from minigpt4.models.moe.utils import (
FeedForward,
MoEModelOutput,
MoEModelOutputWithPooling,
use_experts,
use_experts_route,
moe_layer_judge,
)
from minigpt4.models.moe.route_moe_layer import RouteMoELayer
@@ -378,13 +377,14 @@ class BertOutput(nn.Module): # Add & Norm
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
# Move LayerNorm & ResNet out of FFN After MoEFFN
hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1
return hidden_states


@@ -429,7 +429,7 @@ def __init__(self, config, layer_num):
self.output_query = BertOutput(config)

# Add MoE FFN
self.use_experts = use_experts(layer_num)
self.use_experts = use_experts_route(layer_num)
self.layer_judge = moe_layer_judge(layer_num)
self.num_beams = config.moebert_num_beams
ffn = FeedForward(config)
@@ -442,10 +442,13 @@ def __init__(self, config, layer_num):
num_beams=config.moebert_num_beams,
layer_judge = self.layer_judge,
route_method=config.route_method,
weight_type=config.moe_weight_type,
)
else:
self.experts = ffn

# self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(
self,
hidden_states,
@@ -463,8 +466,8 @@ def forward(
self_attn_past_key_value = (
past_key_value[:2] if past_key_value is not None else None
)
# import pdb;pdb.set_trace()

# import pdb; pdb.set_trace() # 0107test
# adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same
if self.num_beams > 1:
if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams:
@@ -494,10 +497,6 @@ def forward(

present_key_value = self_attention_outputs[-1]

# import pdb;pdb.set_trace()
# print(self.layer_num, hidden_states.shape, attention_mask.shape)


if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]

@@ -526,7 +525,8 @@ def forward(
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
# layer_output = (layer_output, beam_scores, expert_route, beam_idx)
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
# import pdb; pdb.set_trace() # 0107test

if attention_output.shape[1] > query_length: # have text input in Qformer
layer_output_text = apply_chunking_to_forward(
@@ -535,7 +535,8 @@ def forward(
self.seq_len_dim,
attention_output[:, query_length:, :],
)
if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
if self.layer_judge == 'first' and self.num_beams>1:
# if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
# adjust the dimension of layer_output_text to bz*num_beams
layer_output_text = self.adjust_layer_output_text(layer_output_text)

@@ -550,7 +551,9 @@ def forward(
# layer_output & layer_output_text dimen_0 from bz*num_beams to bz
layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text)

layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])
# import pdb; pdb.set_trace() # 0107test

layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4])

else:
layer_output = apply_chunking_to_forward(
@@ -559,7 +562,7 @@ def forward(
self.seq_len_dim,
attention_output,
)
layer_output = (layer_output, None, None)
layer_output = (layer_output, None, None, None, 0.0)

outputs = (layer_output,) + outputs

@@ -594,24 +597,27 @@ def route_moe_last_layer_top1(self, layer_output, layer_output_text):
beam_scores_new = beam_scores[selects]
expert_route_new = expert_route[selects]

return (hidden_states_new, beam_scores_new, expert_route_new), layer_output_text
return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text
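One plausible reading of route_moe_last_layer_top1 (a hedged sketch, assuming beams are stored contiguously per sample; the committed implementation above is authoritative):

    import torch

    def select_top1_beam(hidden_states, beam_scores, num_beams):
        # hidden_states: (bz * num_beams, seq, dim); beam_scores: (bz * num_beams,)
        bz = beam_scores.shape[0] // num_beams
        best = beam_scores.view(bz, num_beams).argmax(dim=-1)                 # best beam per sample
        selects = best + torch.arange(bz, device=beam_scores.device) * num_beams
        return hidden_states[selects], beam_scores[selects]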


def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
# layer_output = self.LayerNorm(layer_output + attention_output)
return layer_output

def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):

if not self.use_experts:
layer_output = self.experts(attention_output)
return layer_output, None, None, None
# layer_output = self.LayerNorm(layer_output + attention_output)
return layer_output, None, None, None, 0.0

layer_output, beam_scores, expert_route, beam_idx = self.experts(
layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
attention_output, expert_attention_mask, beam_scores, expert_route
)
return layer_output, beam_scores, expert_route, beam_idx

# layer_output = self.LayerNorm(layer_output + attention_output)
return layer_output, beam_scores, expert_route, beam_idx, importance_loss
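To keep the new return contract in one place (a sketch; the real routing logic lives in minigpt4/models/moe/route_moe_layer.py): a routed layer now hands back five values, while non-MoE layers return (output, None, None, None, 0.0):

    def route_moe_ffn(experts, attention_output, expert_attention_mask, beam_scores, expert_route):
        # Assumed semantics, inferred from the surrounding diff:
        #   beam_scores     - cumulative routing score per beam
        #   expert_route    - expert chosen at each routed layer, per beam
        #   beam_idx        - parent beam of each kept hypothesis
        #   importance_loss - auxiliary load-balancing term, summed over layers by the encoder
        out, beam_scores, expert_route, beam_idx, importance_loss = experts(
            attention_output, expert_attention_mask, beam_scores, expert_route
        )
        return out, beam_scores, expert_route, beam_idx, importance_loss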

class BertEncoder(nn.Module):
def __init__(self, config):
@@ -645,6 +651,7 @@ def forward(
next_decoder_cache = () if use_cache else None
beam_scores=None
expert_route=None
importance_loss = 0
for i in range(self.config.num_hidden_layers):

layer_module = self.layer[i]
Expand Down Expand Up @@ -693,6 +700,7 @@ def custom_forward(*inputs):
hidden_states = layer_outputs[0][0]
beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1]
expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2]
importance_loss += layer_outputs[0][4]

if use_cache:
next_decoder_cache += (layer_outputs[-1],)
Expand Down Expand Up @@ -724,6 +732,7 @@ def custom_forward(*inputs):
cross_attentions=all_cross_attentions,
beam_scores=beam_scores,
expert_route=expert_route,
gate_loss=importance_loss,
)


@@ -1103,6 +1112,7 @@ def forward(
cross_attentions=encoder_outputs.cross_attentions,
beam_scores=encoder_outputs.beam_scores,
expert_route=encoder_outputs.expert_route,
gate_loss=encoder_outputs.gate_loss
)
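Summarizing the loss plumbing added in this file (an illustrative sketch; field names follow MoEModelOutput in minigpt4/models/moe/utils.py): each layer's importance loss is summed by BertEncoder and surfaced as gate_loss on the encoder and model outputs:

    def accumulate_gate_loss(per_layer_outputs):
        # Each element mimics layer_outputs[0] above:
        # (hidden_states, beam_scores, expert_route, beam_idx, importance_loss).
        total = 0.0
        for layer_output in per_layer_outputs:
            total += layer_output[4]       # 0.0 for layers without routed experts
        return total                       # exposed downstream as `gate_loss`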


3 changes: 2 additions & 1 deletion minigpt4/models/blip2.py
@@ -62,7 +62,7 @@ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
return Qformer, query_tokens

@classmethod
def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, cross_attention_freq=2):
def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2):
moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")

moe_encoder_config.encoder_width = vision_width
@@ -74,6 +74,7 @@ def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num,
moe_encoder_config.moebert_expert_num = moebert_expert_num
moe_encoder_config.moebert_num_beams = moebert_num_beams
moe_encoder_config.route_method = route_method
moe_encoder_config.moe_weight_type = moe_weight_type

RouteMoEQformer = BertMoERouteLMHeadModel.from_pretrained(
"/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
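A hedged usage sketch of the updated classmethod (argument values are placeholders, not taken from the commit; the enclosing class is assumed to be the LAVIS-style Blip2Base, and the return pair is assumed to mirror init_Qformer above):

    # Hypothetical call illustrating the new moe_weight_type argument.
    Qformer, query_tokens = Blip2Base.init_RouteMoEQformer(
        num_query_token=32,             # placeholder
        vision_width=1408,              # placeholder
        moebert_expert_num=4,
        moebert_num_beams=2,
        route_method="pre-route",       # assumption; valid options are defined in RouteMoELayer
        moe_weight_type="ffn_prob",     # new in this commit; exact choices live in the MoE code
        cross_attention_freq=2,
    )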