Skip to content

Commit

Permalink
fix dropout dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jazcollins committed Sep 25, 2023
1 parent 06b2f2f commit 77b0099
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def __getitem__(self, index):
# Microconditioning dropout as in Stability repo
# https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160
if torch.rand(1) < self.microcond_drop_prob:
out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0.0
out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0
if torch.rand(1) < self.microcond_drop_prob:
out['cond_original_size'] = out['cond_original_size'] * 0.0
out['cond_original_size'] = out['cond_original_size'] * 0
if torch.rand(1) < self.microcond_drop_prob:
out['cond_target_size'] = out['cond_target_size'] * 0.0
out['cond_target_size'] = out['cond_target_size'] * 0
else:
crop_top, crop_left, image_height, image_width = None, None, None, None
if self.transform is not None:
Expand Down Expand Up @@ -232,6 +232,9 @@ def build_streaming_image_caption_dataloader(
transform = transforms.Compose(transform)
assert isinstance(transform, Callable)

import streaming
streaming.base.util.clean_stale_shared_memory()

dataset = StreamingImageCaptionDataset(
streams=streams,
tokenizer_name_or_path=tokenizer_name_or_path,
Expand Down

0 comments on commit 77b0099

Please sign in to comment.