Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Jul 24, 2024
1 parent 8b54a6e commit f673a4d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 4 additions & 0 deletions diffusion/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from diffusion.datasets.coco import StreamingCOCOCaption, build_streaming_cocoval_dataloader
from diffusion.datasets.image_caption import StreamingImageCaptionDataset, build_streaming_image_caption_dataloader
from diffusion.datasets.image_caption_latents import (StreamingImageCaptionLatentsDataset,
build_streaming_image_caption_latents_dataloader)
from diffusion.datasets.laion import StreamingLAIONDataset, build_streaming_laion_dataloader
from diffusion.datasets.synthetic_image_caption import (SyntheticImageCaptionDataset,
build_synthetic_image_caption_dataloader)
Expand All @@ -16,6 +18,8 @@
'StreamingCOCOCaption',
'build_streaming_image_caption_dataloader',
'StreamingImageCaptionDataset',
'build_streaming_image_caption_latents_dataloader',
'StreamingImageCaptionLatentsDataset',
'build_synthetic_image_caption_dataloader',
'SyntheticImageCaptionDataset',
]
6 changes: 3 additions & 3 deletions diffusion/datasets/image_caption_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
log = logging.getLogger(__name__)


class StreamingTextLatentsDataset(StreamingDataset):
class StreamingImageCaptionLatentsDataset(StreamingDataset):
"""Streaming dataset for image-caption datasets with pre-computed text latents.
Args:
Expand Down Expand Up @@ -139,7 +139,7 @@ def __getitem__(self, index):
return out


def build_streaming_text_latents_dataloader(
def build_streaming_image_caption_latents_dataloader(
remote: Union[str, List],
batch_size: int,
local: Optional[Union[str, List]] = None,
Expand Down Expand Up @@ -221,7 +221,7 @@ def build_streaming_text_latents_dataloader(
transform = transforms.Compose(transform)
assert isinstance(transform, Callable)

dataset = StreamingTextLatentsDataset(
dataset = StreamingImageCaptionLatentsDataset(
streams=streams,
caption_drop_prob=caption_drop_prob,
microcond_drop_prob=microcond_drop_prob,
Expand Down

0 comments on commit f673a4d

Please sign in to comment.