-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update QformerMoE & DataProcess 1123
- Loading branch information
Showing
86 changed files
with
9,102 additions
and
225 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import argparse | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
import torch.backends.cudnn as cudnn | ||
|
||
import minigpt4.tasks as tasks | ||
from minigpt4.common.config import Config | ||
from minigpt4.common.dist_utils import get_rank, init_distributed_mode | ||
from minigpt4.common.logger import setup_logger | ||
from minigpt4.common.optims import ( | ||
LinearWarmupCosineLRScheduler, | ||
LinearWarmupStepLRScheduler, | ||
) | ||
from minigpt4.common.utils import now | ||
|
||
# imports modules for registration | ||
from minigpt4.datasets.builders import * | ||
from minigpt4.models import * | ||
from minigpt4.processors import * | ||
from minigpt4.runners.runner_base import RunnerBase | ||
from minigpt4.tasks import * | ||
|
||
|
||
def parse_args():
    """Parse this script's command-line arguments.

    Returns:
        argparse.Namespace: carries ``cfg_path`` (required) and ``options``
        (a list of ``xxx=yyy`` strings, or ``None`` when the flag is absent).
    """
    cli = argparse.ArgumentParser(description="Training")
    cli.add_argument(
        "--cfg-path",
        required=True,
        help="path to configuration file.",
    )

    # The help message is one sentence split over adjacent string literals.
    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )
    cli.add_argument("--options", nargs="+", help=options_help)

    return cli.parse_args()
|
||
|
||
def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        config: object exposing ``run_cfg.seed``; the seed actually used is
            that value plus ``get_rank()`` (presumably so every distributed
            process draws a different random stream -- confirm with
            ``dist_utils``).
    """
    base_seed = config.run_cfg.seed
    rank_seed = base_seed + get_rank()

    # Apply the same seed to every random number generator in use.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    # Disable cuDNN autotuning and force deterministic kernels.
    cudnn.deterministic = True
    cudnn.benchmark = False
|
||
|
||
def main():
    """Build the task, datasets and model from the config and run evaluation."""
    # allow auto-dl completes on main process without timeout when using NCCL backend.
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"

    # Taken before init_distributed_mode() so the same job_id is shared
    # across all ranks.
    run_id = now()

    cli_args = parse_args()
    config = Config(cli_args)

    init_distributed_mode(config.run_cfg)
    setup_seeds(config)

    # Configured after init_distributed_mode() so only the master logs.
    setup_logger()
    config.pretty_print()

    eval_task = tasks.setup_task(config)
    built_datasets = eval_task.build_datasets(config)
    built_model = eval_task.build_model(config)

    runner = RunnerBase(
        cfg=config,
        job_id=run_id,
        task=eval_task,
        model=built_model,
        datasets=built_datasets,
    )
    # Evaluation only; no training loop is started here.
    runner.evaluate(skip_reload=True)
|
||
|
||
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import argparse | ||
import numpy as np | ||
from nltk.translate.bleu_score import sentence_bleu | ||
|
||
from minigpt4.common.registry import registry | ||
from minigpt4.common.config import Config | ||
|
||
# imports modules for registration | ||
from minigpt4.datasets.builders import * | ||
from minigpt4.models import * | ||
from minigpt4.processors import * | ||
from minigpt4.runners import * | ||
from minigpt4.tasks import * | ||
|
||
|
||
|
||
def eval_parser():
    """Build the argument parser shared by the evaluation scripts.

    The parser is returned unparsed so that callers can register extra
    arguments before calling ``parse_args()``.

    Returns:
        argparse.ArgumentParser: parser with ``--cfg-path`` (required),
        ``--name``, ``--ckpt``, ``--eval_opt``, ``--max_new_tokens``,
        ``--batch_size``, ``--lora_r``, ``--lora_alpha`` and ``--options``.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--name", type=str, default='A2', help="evaluation name")
    # The help texts of --ckpt and --eval_opt used to be copies of the
    # --cfg-path one, which made ``--help`` output misleading.
    parser.add_argument("--ckpt", type=str, help="path to the model checkpoint.")
    parser.add_argument(
        "--eval_opt",
        type=str,
        default='all',
        help="which evaluation option to run.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=10,
        help="max number of generated tokens",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
    parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser
|
||
|
||
def prepare_texts(texts, conv_temp):
    """Wrap each text in a fresh copy of a conversation template.

    For every text, a copy of ``conv_temp`` receives a first-role message
    consisting of the image placeholder followed by the text, then an empty
    (``None``) second-role message, and is rendered with ``get_prompt()``.

    Args:
        texts: iterable of instruction strings.
        conv_temp: conversation template providing ``copy()``, ``roles``,
            ``append_message(role, message)`` and ``get_prompt()``. It is
            only copied, never modified.

    Returns:
        list: one prompt per input text, in input order.
    """
    prompts = []
    # A plain loop: append_message() is called for its side effect, so it
    # does not belong inside a list comprehension.
    for text in texts:
        conv = conv_temp.copy()
        conv.append_message(conv.roles[0], f'<Img><ImageHere></Img> {text}')
        # A None message leaves the second role's turn open.
        conv.append_message(conv.roles[1], None)
        prompts.append(conv.get_prompt())
    return prompts
|
||
|
||
def init_model(args):
    """Instantiate the configured model and a visual processor.

    Args:
        args: parsed command-line namespace, passed straight to ``Config``.

    Returns:
        tuple: ``(model, vis_processor)``. The model is moved to ``cuda:0``;
        the processor is built from the ``vis_processor.train`` entry of the
        first dataset listed in the config.
    """
    print('Initialization Model')
    cfg = Config(args)
    # NOTE: args.ckpt, args.lora_r and args.lora_alpha are not copied into
    # cfg.model_cfg here; the model is built from the config values alone.

    arch_cfg = cfg.model_cfg
    model = registry.get_model_class(arch_cfg.arch).from_config(arch_cfg)
    model = model.to('cuda:0')

    # Only the first dataset entry is consulted for the processor settings.
    dataset_names = list(cfg.datasets_cfg.keys())
    first_dataset = dataset_names[0]
    processor_cfg = cfg.datasets_cfg.get(first_dataset).vis_processor.train
    processor_cls = registry.get_processor_class(processor_cfg.name)
    vis_processor = processor_cls.from_config(processor_cfg)

    print('Initialization Finished')
    return model, vis_processor
|
||
def computeIoU(bbox1, bbox2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Boxes are ``(x1, y1, x2, y2)`` sequences. Widths and heights are computed
    with the inclusive ``+ 1`` convention, i.e. ``x2 - x1 + 1``.

    Args:
        bbox1: first box as four numbers.
        bbox2: second box as four numbers.

    Returns:
        float: intersection area divided by union area, or ``0.0`` when the
        union area is not positive (degenerate or inverted boxes), instead of
        raising ``ZeroDivisionError`` or returning a negative ratio.
    """
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2

    # Overlap extent on each axis, clamped at zero for disjoint boxes.
    overlap_w = max(0, min(x2, x4) - max(x1, x3) + 1)
    overlap_h = max(0, min(y2, y4) - max(y1, y3) + 1)
    intersection_area = overlap_w * overlap_h

    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area

    # Malformed boxes (e.g. x2 < x1) can give a zero or negative union.
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
# Module-level author metadata.
__author__ = "aagrawal"
Oops, something went wrong.