Skip to content

Commit

Permalink
Trivial fix for undefined symbol in train_dreambooth.py (huggingface#…
Browse files Browse the repository at this point in the history
…1598)

easy fix for undefined name in train_dreambooth.py

import_model_class_from_model_name_or_path loads a pretrained model
and refers to args.revision in a context where args is undefined. I modified
the function to take revision as an argument and modified the invocation
of the function to pass in the revision from args. Seems like this was caused
by a cut and paste.
  • Loading branch information
bcsherma authored Dec 7, 2022
1 parent eb1abee commit 326de41
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
logger = get_logger(__name__)


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
revision=revision,
)
model_class = text_encoder_config.architectures[0]

Expand Down Expand Up @@ -469,7 +469,7 @@ def main(args):
)

# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

# Load models and create wrapper for stable diffusion
text_encoder = text_encoder_cls.from_pretrained(
Expand Down

0 comments on commit 326de41

Please sign in to comment.