diff --git a/flagscale/train/models/llava_onevision/dataloader_provider.py b/flagscale/train/models/llava_onevision/dataloader_provider.py index 9ba67ceca..48b915342 100644 --- a/flagscale/train/models/llava_onevision/dataloader_provider.py +++ b/flagscale/train/models/llava_onevision/dataloader_provider.py @@ -67,11 +67,16 @@ def datasets_provider(worker_config=None): def train_valid_test_dataloaders_provider(train_val_test_num_samples): """Build multimodal train, validation and test dataloaders.""" + args = get_args() + + # In llava-ov, set skip_train False to eval each sample. + # Training while evaluating is not supported yet. + if args.skip_train: + args.eval_iters = args.train_iters + if get_tensor_model_parallel_rank() != 0: return None, None, None - args = get_args() - worker_debug_path = None worker_log_level = 0 @@ -110,11 +115,18 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples): "loading dataloader checkpoint failed. Skipping. " + str(e) ) if args.training_dataset_only: - return ( - EnergonDataloader(train_dataloader), - EnergonDataloader(None), - EnergonDataloader(None), - ) + if not args.skip_train: + return ( + EnergonDataloader(train_dataloader), + None, + None, + ) + else: + return ( + None, + EnergonDataloader(train_dataloader), + None, + ) valid_dataloader = [ EnergonDataloader(get_loader(valid_ds, worker_config=worker_config)) for valid_ds in valid_ds1 diff --git a/flagscale/train/models/llava_onevision/dataset_helpers.py b/flagscale/train/models/llava_onevision/dataset_helpers.py index 0ed334953..1f0650a65 100644 --- a/flagscale/train/models/llava_onevision/dataset_helpers.py +++ b/flagscale/train/models/llava_onevision/dataset_helpers.py @@ -36,6 +36,8 @@ class AnyResTaskSample: images: List[torch.Tensor] image_sizes: List[torch.Tensor] modalities: List[torch.Tensor] + ids: torch.Tensor + ids_shape: torch.Tensor # Typing for the resulting batch data after encode_batch() @dataclass @@ -50,6 +52,8 @@ class 
AnyResTaskBatch(Batch): image_sizes: torch.Tensor split_image_sizes: torch.Tensor modalities: torch.Tensor + ids: torch.Tensor + ids_shape: torch.Tensor class AnyResTaskEncoder(DefaultTaskEncoder[InterleavedSample, InterleavedSample, AnyResTaskBatch, dict]): @@ -84,6 +88,10 @@ def encode_interleaved(self, sample: InterleavedSample): else: assert ValueError("The sequence must have 4 or 5 elements, but got {len(sample.sequence)}.") + id = "".join(sample.__key__.split("/")[1:]) + ids_tensor = torch.tensor([ord(c) for c in id], dtype=torch.uint8) + ids_shape = torch.tensor(ids_tensor.shape) + # process modalities to tensor modalities_list = [] for modality in modalities: @@ -107,7 +115,9 @@ def encode_interleaved(self, sample: InterleavedSample): labels_shape=torch.tensor(labels.shape), images=images, image_sizes=image_sizes, - modalities=modalities + modalities=modalities, + ids=ids_tensor, + ids_shape=ids_shape ) def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch: @@ -121,7 +131,8 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch: # Adapt video data by decord image_sizes = torch.stack([image_sizes if len(image_sizes.shape) == 1 else torch.tensor((1, image_sizes.item())) for s in samples for image_sizes in s.image_sizes], dim=0) modalities = torch.stack([modalities for s in samples for modalities in s.modalities], dim=0) - + ids = torch.cat([s.ids.flatten() for s in samples], dim=0) + ids_shape = torch.stack([s.ids_shape for s in samples], dim=0) batch = AnyResTaskBatch( __keys__=[s.__key__ for s in samples], __subflavors__=[s.__subflavors__ for s in samples], @@ -132,7 +143,9 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch: images=images, image_sizes=image_sizes, split_image_sizes=split_image_sizes, - modalities=modalities + modalities=modalities, + ids=ids, + ids_shape=ids_shape, ) return batch diff --git a/flagscale/train/train_llava_onevision.py b/flagscale/train/train_llava_onevision.py index 
36386a278..7e1a87295 100644 --- a/flagscale/train/train_llava_onevision.py +++ b/flagscale/train/train_llava_onevision.py @@ -191,6 +191,10 @@ def get_batch(data_iterator): labels_shape = tensor_parallel.broadcast_data(["labels_shape"], data, torch.int64)[ "labels_shape" ] + ids = tensor_parallel.broadcast_data(["ids"], data, torch.uint8)["ids"] + ids_shape = tensor_parallel.broadcast_data(["ids_shape"], data, torch.int64)[ + "ids_shape" + ] images = tensor_parallel.broadcast_data(["images"], data, torch.float32)["images"] split_image_sizes = tensor_parallel.broadcast_data( ["split_image_sizes"], data, torch.int64 @@ -229,6 +233,17 @@ def get_batch(data_iterator): assert start_idx == labels.numel() labels = labels_list + # ids to list + ids_list = [] + start_idx = 0 + for shape in ids_shape: + num_elements = torch.prod(shape).item() + sub_tensor = ids[start_idx : start_idx + num_elements].reshape(shape.tolist()) + ids_list.append(sub_tensor) + start_idx += num_elements + assert start_idx == ids.numel() + ids = ids_list + # images to list images_list = [] start_idx = 0 @@ -288,7 +303,7 @@ def get_batch(data_iterator): attention_mask = input_ids.ne(tokenizer.pad_token_id) torch.cuda.nvtx.range_pop() - return input_ids, labels, attention_mask, images, image_sizes, modalities + return input_ids, labels, attention_mask, images, image_sizes, modalities, ids def pad_sequence(input_ids, batch_first, padding_value, tokenizer): @@ -316,7 +331,13 @@ def get_image_token_count(): return num_image_tokens -def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor): +def loss_func( + labels: torch.Tensor, + loss_mask: torch.Tensor, + ids, + logits: torch.Tensor, +): + args = get_args() labels = labels.transpose(0, 1).contiguous() # [b s] => [s b] logits = logits.transpose(0, 1).contiguous() # [b s h] => [s b h] @@ -334,6 +355,11 @@ def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tenso loss = torch.mean(losses) # Reduce loss for 
# Matches the line printed by the evaluation loop, for example
#   "Evaluating id: 000123, loss: 1.5e-05"
# The id is any run of non-whitespace characters (ids are not guaranteed to be
# numeric). The loss accepts everything Python prints for a float: plain
# decimals, scientific notation, "nan" and "inf", optionally signed.
_EVAL_LINE_RE = re.compile(
    r'Evaluating id: (\S+?), loss: '
    r'([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))'
)


def _parse_eval_line(line):
    """Return ``(sample_id, loss)`` parsed from one log line, or ``None``.

    ``None`` is returned when the line does not contain an
    "Evaluating id: ..., loss: ..." record.
    """
    match = _EVAL_LINE_RE.search(line)
    if match is None:
        return None
    return match.group(1), float(match.group(2))


def _keep_sample(sample_id, loss):
    """Decide whether a sample id is written to the output.

    This is the customization point for filtering rules, e.g.
    ``return loss < 0.5``. No filtering is applied currently.
    """
    return True


def _collect_losses(input_dir: str) -> Dict[str, float]:
    """Walk ``input_dir`` and map every evaluated sample id to its loss.

    Only files whose name ends with ``.log`` are read. The same id may be
    reported more than once as long as the loss is identical each time.

    Raises:
        ValueError: if one id is reported with two different losses.
    """
    import math  # local import: only this helper needs it

    losses: Dict[str, float] = {}
    for root, _, files in os.walk(input_dir):
        for name in files:
            if not name.endswith('.log'):
                continue
            path = os.path.join(root, name)
            # errors='replace' keeps a stray undecodable byte in a log from
            # aborting the whole scan; the records of interest are ASCII.
            with open(path, 'r', encoding='utf-8', errors='replace') as f:
                # Stream the file instead of loading it whole; logs can be large.
                for line in f:
                    parsed = _parse_eval_line(line)
                    if parsed is None:
                        continue
                    sample_id, loss = parsed
                    if sample_id in losses:
                        previous = losses[sample_id]
                        # nan != nan, so treat two nan reports as consistent.
                        consistent = previous == loss or (
                            math.isnan(previous) and math.isnan(loss)
                        )
                        if not consistent:
                            raise ValueError(
                                f"Conflicting losses for id {sample_id}: "
                                f"{previous} vs {loss} (in {path})"
                            )
                    losses[sample_id] = loss
    return losses


def main(argv=None):
    """Collect evaluated sample ids from log files and save them as JSON.

    Scans ``--input_dir`` recursively for ``*.log`` files, extracts every
    "Evaluating id: <id>, loss: <loss>" record, applies ``_keep_sample`` and
    writes ``{"ids": [...]}`` to ``--output``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: via ``parser.error`` when an argument is missing,
            ``--output`` does not end with ``.json``, or ``--input_dir`` is
            not a directory.
        ValueError: if one id is reported with two different losses.
    """
    parser = argparse.ArgumentParser(description='Grep id and loss from log files.')
    parser.add_argument('--input_dir', type=str, required=True,
                        help='Directory to search log files.')
    parser.add_argument('--output', type=str, required=True,
                        help='Path to save the result.')
    args = parser.parse_args(argv)

    # Validate up front so a bad argument fails before any log is scanned.
    if not args.output.endswith(".json"):
        parser.error("--output must end with '.json'")
    if not os.path.isdir(args.input_dir):
        parser.error(f"--input_dir is not a directory: {args.input_dir}")

    losses = _collect_losses(args.input_dir)
    result = {
        "ids": [sample_id for sample_id, loss in losses.items()
                if _keep_sample(sample_id, loss)]
    }
    if not result["ids"]:
        print("Warning: no sample ids were selected; the output list is empty.")

    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4)
    print("Done")
entry["ids"][0].cpu().numpy()]) + if filter_ids and id not in filter_ids: + print(f"The id {id} is filtered out.") + continue sequence = [] sequence.append(entry['input_ids'][0].cpu()) @@ -1878,15 +1887,16 @@ def make_inputs_require_grad(module, input, output): sequence.append([torch.tensor(entry['image_sizes'][0])]) sequence.append([entry['modalities'][0]]) if entry['modalities'][0] == "video": - print(f"Processing video and image_sizes: {entry['image_sizes'][0]}, {images.shape}") + print(f"Processing id {id} video and image_sizes: {entry['image_sizes'][0]}, {images.shape}") elif entry['modalities'][0] == "text": - print("Processing text.") + print(f"Processing id {id} text.") elif entry['modalities'][0] == "image": - print("Processing single image.") + print(f"Processing id {id} single image.") else: raise ValueError() else: # Process images + print(f"Processing id {id} multi images.") images = [] each_image_shape = None for image in entry['images']: @@ -1912,15 +1922,14 @@ def make_inputs_require_grad(module, input, output): sequence.append(images) sequence.append(image_sizes) - sequence.append(modalities) + sequence.append(modalities) sample = { - "__key__": str(global_id), + "__key__": str(id), "sequence.pyd": sequence, } shard_writer.write(sample) - global_id += 1 print(f"rank {dist.get_rank()} datasets saved to {training_args.output_dir}") diff --git a/tools/datasets/llava_onevision/make_llava_ov_wds.sh b/tools/datasets/llava_onevision/make_llava_ov_wds.sh index bb6a25c90..2b70edc90 100644 --- a/tools/datasets/llava_onevision/make_llava_ov_wds.sh +++ b/tools/datasets/llava_onevision/make_llava_ov_wds.sh @@ -27,6 +27,12 @@ set -u HOSTFILE=$3 set +u +if [ $# -ge 4 ]; then + FILTER_JSON=$4 +else + FILTER_JSON="" +fi + echo "BASE_RUN_NAME: ${EXPNAME_PATH}" CKPT_PATH="./checkpoints" @@ -52,6 +58,7 @@ do export WANDB_MODE=offline && \ export ACCELERATE_CPU_AFFINITY=1 && \ export PYTHONPATH=$LLaVA_NeXT_HOME:$PYTHONPATH && \ + export FILTER_JSON=$FILTER_JSON && \ 
source /root/miniconda3/bin/activate flagscale && \ torchrun --nproc_per_node=8 --nnodes=${NNodes} --node_rank=${rank} --master_addr=${MASTER_ADDR} --master_port=13888 llava_ov_wds.py \ --model_name_or_path ${CKPT_PATH} \