diff --git a/omni_speech/model/builder.py b/omni_speech/model/builder.py
index c4f6df0..45f0e16 100644
--- a/omni_speech/model/builder.py
+++ b/omni_speech/model/builder.py
@@ -52,7 +52,7 @@ def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load
         model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
         print('Loading additional OmniSpeech weights...')
         if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
-            non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+            non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu', weights_only=True)
         non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
         if any(k.startswith('model.model.') for k in non_lora_trainables):
             non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
@@ -70,7 +70,7 @@ def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load
         cfg_pretrained = AutoConfig.from_pretrained(model_path)
         model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
 
-        speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
+        speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu', weights_only=True)
         speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
         model.load_state_dict(speech_projector_weights, strict=False)
         model = model.to(device=device)
diff --git a/omni_speech/model/omni_speech_arch.py b/omni_speech/model/omni_speech_arch.py
index f027566..cd8e033 100644
--- a/omni_speech/model/omni_speech_arch.py
+++ b/omni_speech/model/omni_speech_arch.py
@@ -60,7 +60,7 @@ def initialize_speech_modules(self, model_args, fsdp=None):
                 p.requires_grad = True
 
         if model_args.pretrain_speech_projector is not None:
-            pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu')
+            pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu', weights_only=True)
 
             def get_w(weights, keyword):
                 return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
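For context, a minimal sketch (not part of the patch) of what the added weights_only=True flag does: it swaps the full pickle unpickler inside torch.load for a restricted one that only reconstructs tensors, primitive types, and simple containers, so loading an untrusted checkpoint can no longer execute arbitrary code via a crafted pickle payload. The file name demo_ckpt.bin below is illustrative, not from the repository.

import torch

# Save a plain state dict, like the OmniSpeech checkpoints touched above.
state_dict = {'speech_projector.weight': torch.randn(4, 4)}
torch.save(state_dict, 'demo_ckpt.bin')  # hypothetical file name

# With weights_only=True (available since PyTorch 1.13, the default from 2.6),
# only tensors and simple containers are unpickled; a checkpoint carrying a
# malicious __reduce__ payload raises an UnpicklingError instead of running.
loaded = torch.load('demo_ckpt.bin', map_location='cpu', weights_only=True)
print(loaded['speech_projector.weight'].shape)  # torch.Size([4, 4])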