diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index 753b351..431583f 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -167,16 +167,20 @@ def __getitem__(self, index):
         else:
             text = descriptions
         # max length from the paper
-        encoded = self.tokenizer.batch_encode_plus(
-            [str(text)],
-            return_tensors="pt",
-            padding="max_length",
-            max_length=MAX_LENGTH,
-            truncation=True,
-        )
+        if self.tokenizer is not None:
+            encoded = self.tokenizer.batch_encode_plus(
+                [str(text)],
+                return_tensors="pt",
+                padding="max_length",
+                max_length=MAX_LENGTH,
+                truncation=True,
+            )

-        input_ids = encoded.input_ids
-        attn_mask = encoded.attention_mask
+            input_ids = encoded.input_ids
+            attn_mask = encoded.attention_mask
+        else:
+            input_ids = []
+            attn_mask = []

         if self.using_taming:
             if self.embeds:
@@ -242,16 +246,20 @@ def __getitem__(self, index):
         else:
             text = descriptions
         # max length from the paper
-        encoded = self.tokenizer.batch_encode_plus(
-            [str(text)],
-            return_tensors="pt",
-            padding="max_length",
-            max_length=MAX_LENGTH,
-            truncation=True,
-        )
+        if self.tokenizer is not None:
+            encoded = self.tokenizer.batch_encode_plus(
+                [str(text)],
+                return_tensors="pt",
+                padding="max_length",
+                max_length=MAX_LENGTH,
+                truncation=True,
+            )

-        input_ids = encoded.input_ids
-        attn_mask = encoded.attention_mask
+            input_ids = encoded.input_ids
+            attn_mask = encoded.attention_mask
+        else:
+            input_ids = []
+            attn_mask = []

         if self.using_taming:
             if self.embeds:
@@ -338,16 +346,21 @@ def __getitem__(self, index):
             embed = self.embeds[index]
         # max length from the paper
-        encoded = self.tokenizer.batch_encode_plus(
-            [str(text)],
-            return_tensors="pt",
-            padding="max_length",
-            max_length=MAX_LENGTH,
-            truncation=True,
-        )
+        if self.tokenizer is not None:
+            encoded = self.tokenizer.batch_encode_plus(
+                [str(text)],
+                return_tensors="pt",
+                padding="max_length",
+                max_length=MAX_LENGTH,
+                truncation=True,
+            )
+
+            input_ids = encoded.input_ids
+            attn_mask = encoded.attention_mask
+        else:
+            input_ids = []
+            attn_mask = []

-        input_ids = encoded.input_ids
-        attn_mask = encoded.attention_mask

         if self.using_taming:
             if self.embeds:
                 return self.transform(image) - 0.5, input_ids[0], attn_mask[0], embed, text
diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index 2b3fd7b..0199b2c 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -185,6 +185,7 @@ def __init__(
         self.norm = LayerNorm(dim)

         self.use_clip = use_clip
+        self.tokenizer = None

         self.dim_out = default(dim_out, num_tokens)
         self.to_logits = nn.Linear(dim, self.dim_out, bias=False)
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index b8efe29..c573a9d 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -824,6 +824,9 @@ def main():
     else:
         embeds = []

+    if args.use_metaclip:
+        transformer.tokenizer = None
+
     # Create the dataset objects
     with accelerator.main_process_first():
         if args.no_cache and args.train_data_dir:
@@ -1003,8 +1006,8 @@ def main():
            clip_precision = args.mixed_precision

        clip = open_clip.create_model_and_transforms(
-            "ViT-L-14",
-            pretrained="metaclip/l14_400m.pt",
+            "convnext_base_w",
+            pretrained="laion2b_s13b_b82k_augreg",
             cache_dir=args.cache_path,
             precision=clip_precision,
             device=accelerator.device,
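
For reference, the guard this patch inserts into all three __getitem__ branches can be read as the standalone sketch below. It is illustrative only: encode_caption is a hypothetical helper that does not exist in this repository, and the MAX_LENGTH value is a stand-in for the dataset module's own constant ("max length from the paper").

# Illustrative sketch only -- encode_caption is a hypothetical helper, not a
# function in this repository, and MAX_LENGTH stands in for the dataset
# module's own constant ("max length from the paper").
from typing import Optional, Tuple, Union

import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase

MAX_LENGTH = 256  # placeholder value


def encode_caption(
    tokenizer: Optional[PreTrainedTokenizerBase], text: str
) -> Tuple[Union[torch.Tensor, list], Union[torch.Tensor, list]]:
    # No tokenizer configured (e.g. the MetaCLIP path sets
    # transformer.tokenizer to None): return empty sequences so callers
    # can skip T5-style text conditioning.
    if tokenizer is None:
        return [], []
    # Otherwise tokenize exactly as the dataset does: pad/truncate to a
    # fixed length and return PyTorch tensors.
    encoded = tokenizer.batch_encode_plus(
        [str(text)],
        return_tensors="pt",
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
    )
    return encoded.input_ids, encoded.attention_mask


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("t5-small")
    ids, mask = encode_caption(tok, "a photo of a cat")
    print(ids.shape, mask.shape)            # torch.Size([1, 256]) for both
    print(encode_caption(None, "ignored"))  # ([], []) -- tokenizer-less path

Factoring the guard out this way would also avoid repeating the same tokenizer-optional block in three places, which is the main maintenance cost of the patch as written.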
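
The last hunk only changes which open_clip backbone is loaded. A minimal sketch of that call in isolation, assuming only that open_clip is installed (in the training script, cache_dir, precision, and device are also forwarded from the parsed arguments):

import open_clip

# create_model_and_transforms returns a
# (model, preprocess_train, preprocess_val) tuple.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    "convnext_base_w",
    pretrained="laion2b_s13b_b82k_augreg",
)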