diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index a9f480e..0ff1045 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -196,21 +196,20 @@ def load_diffusers_checkpoint(self): precision=precision, ) - if len(self.config.unet_checkpoint) > 0: - unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( - self.config.unet_checkpoint, - split_head_dim=self.config.split_head_dim, - norm_num_groups=self.config.norm_num_groups, - attention_kernel=self.config.attention, - flash_block_sizes=flash_block_sizes, - dtype=self.activations_dtype, - weights_dtype=self.weights_dtype, - mesh=self.mesh, - ) - params["unet"] = unet_params - pipeline.unet = unet - params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params) - + if len(self.config.unet_checkpoint) > 0: + unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( + self.config.unet_checkpoint, + split_head_dim=self.config.split_head_dim, + norm_num_groups=self.config.norm_num_groups, + attention_kernel=self.config.attention, + flash_block_sizes=flash_block_sizes, + dtype=self.activations_dtype, + weights_dtype=self.weights_dtype, + mesh=self.mesh, + ) + params["unet"] = unet_params + pipeline.unet = unet + params = jax.tree_util.tree_map(lambda x: x.astype(self.config.weights_dtype), params) return pipeline, params def save_checkpoint(self, train_step, pipeline, params, train_states): diff --git a/src/maxdiffusion/models/modeling_flax_utils.py b/src/maxdiffusion/models/modeling_flax_utils.py index f113669..3983c22 100644 --- a/src/maxdiffusion/models/modeling_flax_utils.py +++ b/src/maxdiffusion/models/modeling_flax_utils.py @@ -34,6 +34,7 @@ FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, PushToHubMixin, logging, ) @@ -331,9 +332,12 @@ def from_pretrained( ) if os.path.isdir(pretrained_path_with_subfolder): if from_pt: - if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): + if os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): + model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, SAFETENSORS_WEIGHTS_NAME)): + model_file = os.path.join(pretrained_path_with_subfolder, SAFETENSORS_WEIGHTS_NAME) + else: raise EnvironmentError(f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} ") - model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)