llava ov eval and filter
caozhou committed Jan 3, 2025
1 parent 77e104c commit 35afa95
Showing 6 changed files with 135 additions and 25 deletions.
26 changes: 19 additions & 7 deletions flagscale/train/models/llava_onevision/dataloader_provider.py
@@ -67,11 +67,16 @@ def datasets_provider(worker_config=None):

def train_valid_test_dataloaders_provider(train_val_test_num_samples):
"""Build multimodal train, validation and test dataloaders."""
args = get_args()

# In llava-ov, set skip_train to True to evaluate every sample once.
# Training while evaluating is not supported yet.
if args.skip_train:
args.eval_iters = args.train_iters

if get_tensor_model_parallel_rank() != 0:
return None, None, None

args = get_args()

worker_debug_path = None
worker_log_level = 0

@@ -110,11 +115,18 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples):
"loading dataloader checkpoint failed. Skipping. " + str(e)
)
if args.training_dataset_only:
return (
EnergonDataloader(train_dataloader),
EnergonDataloader(None),
EnergonDataloader(None),
)
if not args.skip_train:
return (
EnergonDataloader(train_dataloader),
None,
None,
)
else:
return (
None,
EnergonDataloader(train_dataloader),
None,
)
valid_dataloader = [
EnergonDataloader(get_loader(valid_ds, worker_config=worker_config))
for valid_ds in valid_ds1
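For readers of the hunk above: under skip_train the provider routes the training data into the validation slot, and eval_iters is set to train_iters so the evaluation loop visits every sample exactly once. A tiny self-contained sketch of that routing contract (toy stand-in for the dataloaders; not the FlagScale API):

def route_dataloaders(train_loader, skip_train):
    # Mirrors the provider's return contract: with skip_train, the training
    # data is exposed as the validation loader so each sample gets evaluated.
    if not skip_train:
        return train_loader, None, None
    return None, train_loader, None

train_dl, valid_dl, test_dl = route_dataloaders(["sample-0", "sample-1"], skip_train=True)
assert train_dl is None and valid_dl == ["sample-0", "sample-1"]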
19 changes: 16 additions & 3 deletions flagscale/train/models/llava_onevision/dataset_helpers.py
@@ -36,6 +36,8 @@ class AnyResTaskSample:
images: List[torch.Tensor]
image_sizes: List[torch.Tensor]
modalities: List[torch.Tensor]
ids: torch.Tensor
ids_shape: torch.Tensor

# Typing for the resulting batch data after encode_batch()
@dataclass
@@ -50,6 +52,8 @@ class AnyResTaskBatch(Batch):
image_sizes: torch.Tensor
split_image_sizes: torch.Tensor
modalities: torch.Tensor
ids: torch.Tensor
ids_shape: torch.Tensor


class AnyResTaskEncoder(DefaultTaskEncoder[InterleavedSample, InterleavedSample, AnyResTaskBatch, dict]):
@@ -84,6 +88,10 @@ def encode_interleaved(self, sample: InterleavedSample):
else:
raise ValueError(f"The sequence must have 4 or 5 elements, but got {len(sample.sequence)}.")

id = "".join(sample.__key__.split("/")[1:])
ids_tensor = torch.tensor([ord(c) for c in id], dtype=torch.uint8)
ids_shape = torch.tensor(ids_tensor.shape)

# process modalities to tensor
modalities_list = []
for modality in modalities:
@@ -107,7 +115,9 @@ def encode_interleaved(self, sample: InterleavedSample):
labels_shape=torch.tensor(labels.shape),
images=images,
image_sizes=image_sizes,
modalities=modalities
modalities=modalities,
ids=ids_tensor,
ids_shape=ids_shape
)

def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
@@ -121,7 +131,8 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
# Adapt video data by decord
image_sizes = torch.stack([image_sizes if len(image_sizes.shape) == 1 else torch.tensor((1, image_sizes.item())) for s in samples for image_sizes in s.image_sizes], dim=0)
modalities = torch.stack([modalities for s in samples for modalities in s.modalities], dim=0)

ids = torch.cat([s.ids.flatten() for s in samples], dim=0)
ids_shape = torch.stack([s.ids_shape for s in samples], dim=0)
batch = AnyResTaskBatch(
__keys__=[s.__key__ for s in samples],
__subflavors__=[s.__subflavors__ for s in samples],
Expand All @@ -132,7 +143,9 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
images=images,
image_sizes=image_sizes,
split_image_sizes=split_image_sizes,
modalities=modalities
modalities=modalities,
ids=ids,
ids_shape=ids_shape,
)

return batch
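A note on the new ids plumbing above: a string sample id cannot go through tensor_parallel.broadcast_data, so it is shipped as a uint8 tensor of character codes plus a companion shape tensor, and decoded back with chr on the consumer side. A small self-contained round-trip sketch of that convention (helper names are illustrative; assumes ASCII keys, since code points above 255 would not fit in uint8):

import torch

def encode_id(sample_key: str):
    # Same convention as encode_interleaved: one uint8 per character.
    ids = torch.tensor([ord(c) for c in sample_key], dtype=torch.uint8)
    return ids, torch.tensor(ids.shape)

def decode_id(ids: torch.Tensor) -> str:
    # Inverse of encode_id; mirrors the chr() loop used after broadcasting.
    return "".join(chr(c) for c in ids.cpu().numpy())

ids_tensor, ids_shape = encode_id("00123")
assert decode_id(ids_tensor) == "00123"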
34 changes: 30 additions & 4 deletions flagscale/train/train_llava_onevision.py
@@ -191,6 +191,10 @@ def get_batch(data_iterator):
labels_shape = tensor_parallel.broadcast_data(["labels_shape"], data, torch.int64)[
"labels_shape"
]
ids = tensor_parallel.broadcast_data(["ids"], data, torch.uint8)["ids"]
ids_shape = tensor_parallel.broadcast_data(["ids_shape"], data, torch.int64)[
"ids_shape"
]
images = tensor_parallel.broadcast_data(["images"], data, torch.float32)["images"]
split_image_sizes = tensor_parallel.broadcast_data(
["split_image_sizes"], data, torch.int64
@@ -229,6 +233,17 @@ def get_batch(data_iterator):
assert start_idx == labels.numel()
labels = labels_list

# ids to list
ids_list = []
start_idx = 0
for shape in ids_shape:
num_elements = torch.prod(shape).item()
sub_tensor = ids[start_idx : start_idx + num_elements].reshape(shape.tolist())
ids_list.append(sub_tensor)
start_idx += num_elements
assert start_idx == ids.numel()
ids = ids_list

# images to list
images_list = []
start_idx = 0
@@ -288,7 +303,7 @@ def get_batch(data_iterator):
attention_mask = input_ids.ne(tokenizer.pad_token_id)
torch.cuda.nvtx.range_pop()

return input_ids, labels, attention_mask, images, image_sizes, modalities
return input_ids, labels, attention_mask, images, image_sizes, modalities, ids


def pad_sequence(input_ids, batch_first, padding_value, tokenizer):
@@ -316,7 +331,13 @@ def get_image_token_count():
return num_image_tokens


def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
def loss_func(
labels: torch.Tensor,
loss_mask: torch.Tensor,
ids,
logits: torch.Tensor,
):
args = get_args()
labels = labels.transpose(0, 1).contiguous() # [b s] => [s b]
logits = logits.transpose(0, 1).contiguous() # [b s h] => [s b h]

@@ -334,6 +355,11 @@ def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
loss = torch.mean(losses)

# Reduce loss for logging.
if args.skip_train:
assert isinstance(ids, list) and len(ids) == 1
id = "".join([chr(c) for c in ids[0].cpu().numpy()])
print(f"Evaluating id: {id}, loss: {loss.detach().clone().item()}", flush=True)

averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}

@@ -354,7 +380,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):

# Get the batch.
timers("batch-generator", log_level=2).start()
input_ids, labels, attention_mask, images, image_sizes, modalities = get_batch(
input_ids, labels, attention_mask, images, image_sizes, modalities, ids = get_batch(
data_iterator
)
if "text" in modalities and ("image" in modalities or "video" in modalities):
Expand All @@ -367,7 +393,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):
input_ids, labels, attention_mask, images, image_sizes, modalities
)

return output_tensor, partial(loss_func, labels, loss_mask)
return output_tensor, partial(loss_func, labels, loss_mask, ids)


def add_multimodal_extra_args(parser):
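The ids handling in get_batch above reuses the pattern already applied to labels and images: variable-length tensors are flattened for broadcasting and then split back per sample using the companion *_shape tensor. A standalone sketch of that unflatten step (function name is illustrative):

import torch

def unflatten_by_shapes(flat, shapes):
    # flat:   1-D tensor holding all samples back to back (e.g. the broadcast ids)
    # shapes: [num_samples, ndim] tensor of per-sample shapes (e.g. ids_shape)
    pieces, start = [], 0
    for shape in shapes:
        numel = torch.prod(shape).item()
        pieces.append(flat[start : start + numel].reshape(shape.tolist()))
        start += numel
    assert start == flat.numel(), "shapes do not cover the whole flat tensor"
    return pieces

# Two ids of length 3 and 5 packed into one uint8 tensor.
flat = torch.arange(8, dtype=torch.uint8)
pieces = unflatten_by_shapes(flat, torch.tensor([[3], [5]]))
assert [p.numel() for p in pieces] == [3, 5]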
43 changes: 43 additions & 0 deletions tools/datasets/llava_onevision/filter.py
@@ -0,0 +1,43 @@
import os
import re
import json
import argparse
from typing import Dict


def main():
parser = argparse.ArgumentParser(description='Grep id and loss from log files.')
parser.add_argument('--input_dir', type=str, help='Directory to search log files.')
parser.add_argument('--output', type=str, help='Path to save the result.')
args = parser.parse_args()

result_dict: Dict[str, float] = {}
for root, dirs, files in os.walk(args.input_dir):
for file in files:
if file.endswith('.log'):
file_path = os.path.join(root, file)
with open(file_path, 'r') as f:
lines = f.readlines()
for line in lines:
match = re.search(r'Evaluating id: (\d+), loss: ([\d.]+)', line)
if match:
evaluating_id = match.group(1)
loss = float(match.group(2))
if evaluating_id in result_dict:
assert loss == result_dict[evaluating_id]
# Customize the filtering rule here, e.g.:
# if loss < 0.5:
# result_dict[evaluating_id] = loss

# NOTE: no filtering is applied currently; comment out the line below when adding a custom rule above.
result_dict[evaluating_id] = loss

result = {"ids": list(result_dict.keys())}
assert args.output.endswith(".json")
with open(args.output, 'w') as f:
json.dump(result, f, indent=4)
print("Done")


if __name__ == "__main__":
main()
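filter.py currently keeps every id it sees; the commented block marks where a rule belongs. A self-contained sketch of one possible rule, keeping only ids whose evaluation loss falls below a threshold and writing the same {"ids": [...]} layout the script produces (the threshold, id/loss pairs, and output path are illustrative):

import json

parsed = {"00017": 0.31, "00042": 1.87, "00103": 0.44}  # id -> loss, as grepped from the logs
LOSS_THRESHOLD = 0.5  # illustrative cutoff

kept = {i: l for i, l in parsed.items() if l < LOSS_THRESHOLD}
with open("filter_ids.json", "w") as f:  # in filter.py this path comes from --output
    json.dump({"ids": list(kept.keys())}, f, indent=4)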
31 changes: 20 additions & 11 deletions tools/datasets/llava_onevision/llava_ov_wds.py
@@ -1432,6 +1432,8 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
# batch["images"] = torch.stack(images)
# else:
batch["images"] = images
assert "id" in instances[0]
batch["ids"] = [torch.tensor([ord(c) for c in instance["id"]], dtype=torch.uint8) for instance in instances]

if "prompt" in instances[0]:
batch["prompts"] = [instance["prompt"] for instance in instances]
@@ -1847,15 +1849,22 @@ def make_inputs_require_grad(module, input, output):
if not os.path.exists(output):
os.mkdir(output)
start_time = time.time()
filter_ids = []
filter_json = os.environ.get("FILTER_JSON", "")
if filter_json:
with open(filter_json, 'r') as file:
data = json.load(file)
filter_ids = data["ids"]
with wds.ShardWriter(os.path.join(output, f'llava-ov-{dist.get_rank()}-%d.tar'), maxcount=10000) as shard_writer:
dataloader = trainer.get_train_dataloader()
print(f"sample num: {len(dataloader)}")
global_id = 0
for entry in tqdm(dataloader):
if global_id == 0:
for x in entry.keys():
#print(f"key={x}, type={type(entry[x])}")
pass
assert 'ids' in entry
assert len(entry["ids"]) == 1
id = "".join([chr(c) for c in entry["ids"][0].cpu().numpy()])
if filter_ids and id not in filter_ids:
print(f"The id {id} is filtered out.")
continue

sequence = []
sequence.append(entry['input_ids'][0].cpu())
@@ -1878,15 +1887,16 @@ def make_inputs_require_grad(module, input, output):
sequence.append([torch.tensor(entry['image_sizes'][0])])
sequence.append([entry['modalities'][0]])
if entry['modalities'][0] == "video":
print(f"Processing video and image_sizes: {entry['image_sizes'][0]}, {images.shape}")
print(f"Processing id {id} video and image_sizes: {entry['image_sizes'][0]}, {images.shape}")
elif entry['modalities'][0] == "text":
print("Processing text.")
print(f"Processing id {id} text.")
elif entry['modalities'][0] == "image":
print("Processing single image.")
print(f"Processing id {id} single image.")
else:
raise ValueError()
else:
# Process images
print(f"Processing id {id} multi images.")
images = []
each_image_shape = None
for image in entry['images']:
@@ -1912,15 +1922,14 @@ def make_inputs_require_grad(module, input, output):

sequence.append(images)
sequence.append(image_sizes)
sequence.append(modalities)
sequence.append(modalities)

sample = {
"__key__": str(global_id),
"__key__": str(id),
"sequence.pyd": sequence,
}

shard_writer.write(sample)
global_id += 1

print(f"rank {dist.get_rank()} datasets saved to {training_args.output_dir}")

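One small design note on the filter lookup in the hunk above: filter_ids is loaded as a list, so "id not in filter_ids" is a linear scan for every sample. A hedged drop-in sketch that loads the same FILTER_JSON layout into a set, which keeps the membership check O(1) on large filter files:

import json
import os

filter_ids = set()
filter_json = os.environ.get("FILTER_JSON", "")
if filter_json:
    with open(filter_json, "r") as f:
        # Same {"ids": [...]} layout written by tools/datasets/llava_onevision/filter.py
        filter_ids = set(json.load(f)["ids"])

# Inside the shard-writing loop (illustrative):
# if filter_ids and sample_id not in filter_ids:
#     continue  # skip samples that were filtered out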
7 changes: 7 additions & 0 deletions tools/datasets/llava_onevision/make_llava_ov_wds.sh
@@ -27,6 +27,12 @@ set -u
HOSTFILE=$3
set +u

if [ $# -ge 4 ]; then
FILTER_JSON=$4
else
FILTER_JSON=""
fi

echo "BASE_RUN_NAME: ${EXPNAME_PATH}"

CKPT_PATH="./checkpoints"
@@ -52,6 +58,7 @@
export WANDB_MODE=offline && \
export ACCELERATE_CPU_AFFINITY=1 && \
export PYTHONPATH=$LLaVA_NeXT_HOME:$PYTHONPATH && \
export FILTER_JSON=$FILTER_JSON && \
source /root/miniconda3/bin/activate flagscale && \
torchrun --nproc_per_node=8 --nnodes=${NNodes} --node_rank=${rank} --master_addr=${MASTER_ADDR} --master_port=13888 llava_ov_wds.py \
--model_name_or_path ${CKPT_PATH} \
