Skip to content

Commit

Permalink
Add support for loading multiple datasets (Sygil-Dev#78)
Browse files Browse the repository at this point in the history
They **MUST** have the same image and text column for this to work, you
can split them by inserting `|` between dataset names (no spaces between
the `|`, this will break it, I'll probably add a `.strip()` later)
  • Loading branch information
ZeroCool940711 authored Oct 2, 2023
2 parents 952d778 + ac99004 commit a1ca68d
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 32 deletions.
60 changes: 46 additions & 14 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
import transformers
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from datasets import concatenate_datasets, load_dataset
from diffusers.optimization import SchedulerType, get_scheduler
from omegaconf import OmegaConf
from rich import inspect
Expand Down Expand Up @@ -237,7 +237,8 @@ def decompress_pickle(file):
"--dataset_name",
type=str,
default=None,
help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir)",
help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir, use multiple by splitting with '|', "
"they must have the same image column and text column)",
)
parser.add_argument(
"--hf_split_name",
Expand Down Expand Up @@ -605,18 +606,49 @@ def main():
save_path=args.dataset_save_path,
)
elif args.dataset_name is not None:
dataset = load_dataset(
args.dataset_name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)
if args.streaming:
if args.cache_path:
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
else:
dataset = load_dataset(args.dataset_name)[args.hf_split_name]
if "|" in args.dataset_name:
loaded_datasets = []
for name in args.dataset_name.split("|"):
accelerator.print(f"Loading {name}")
data_to_add = load_dataset(
name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

data_to_add.remove_columns(
[
col
for col in data_to_add.column_names
if col != args.caption_column or col != args.image_column
]
)

loaded_datasets.append(data_to_add)

try:
dataset = concatenate_datasets(loaded_datasets)
except ValueError:
raise UserWarning("Failed concatenating dataset... Make sure they use the same columns!")

else:
dataset = load_dataset(
args.dataset_name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

if args.streaming:
if args.cache_path:
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[
args.hf_split_name
]
else:
dataset = load_dataset(args.dataset_name)[args.hf_split_name]
else:
raise ValueError("You must pass either train_data_dir or dataset_name (but not both)")

Expand Down
64 changes: 46 additions & 18 deletions train_muse_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from dataclasses import dataclass
from typing import Optional, Union

import wandb
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset, Dataset, Image
from datasets import Dataset, Image, concatenate_datasets, load_dataset
from omegaconf import OmegaConf

import wandb
from muse_maskgit_pytorch import (
VQGanVAE,
VQGanVAETaming,
Expand Down Expand Up @@ -163,7 +163,8 @@
"--dataset_name",
type=str,
default=None,
help="Name of the huggingface dataset used.",
help="ID of HuggingFace dataset to use (cannot be used with --train_data_dir, use multiple by splitting with '|', "
"they must have the same image column and text column)",
)
parser.add_argument(
"--hf_split_name",
Expand Down Expand Up @@ -409,7 +410,7 @@ def main():
args = parser.parse_args(namespace=Arguments())

if args.config_path:
accelerator.print("Using config file and ignoring CLI args")
print("Using config file and ignoring CLI args")

try:
conf = OmegaConf.load(args.config_path)
Expand All @@ -420,10 +421,10 @@ def main():
try:
args_to_convert[key] = conf[key]
except KeyError:
accelerator.print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")
print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")

except FileNotFoundError:
accelerator.print("Could not find config, using default and parsed values...")
print("Could not find config, using default and parsed values...")

project_config = ProjectConfiguration(
project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
Expand Down Expand Up @@ -464,18 +465,43 @@ def main():
save=not args.no_cache,
)
elif args.dataset_name:
if args.cache_path:
dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)[
"train"
]
if "|" in args.dataset_name:
loaded_datasets = []
for name in args.dataset_name.split("|"):
accelerator.print(f"Loading {name}")
data_to_add = load_dataset(
name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

data_to_add.remove_columns(
[
col
for col in data_to_add.column_names
if col != args.caption_column or col != args.image_column
]
)

loaded_datasets.append(data_to_add)

try:
dataset = concatenate_datasets(loaded_datasets)
except ValueError:
raise UserWarning("Failed concatenating dataset... Make sure they use the same columns!")

else:
dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)[
"train"
]
if args.streaming:
if dataset.info.dataset_size is None:
accelerator.print("Dataset doesn't support streaming, disabling streaming")
args.streaming = False
dataset = load_dataset(
args.dataset_name,
streaming=args.streaming,
cache_dir=args.cache_path,
save_infos=True,
split="train",
)

if args.streaming:
if args.cache_path:
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
else:
Expand Down Expand Up @@ -610,7 +636,9 @@ def main():
filepaths.append(os.path.join(root, file))

if not filepaths:
print(f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}.")
print(
f"No images with extensions {extensions} found in {args.validation_folder_at_end_of_epoch}."
)
exit(1)

epoch_validation_dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image())
Expand Down

0 comments on commit a1ca68d

Please sign in to comment.