diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index 401c6d1ec..c19ae9749 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -142,12 +142,23 @@ def main(args):
 
         del model
 
-        ort_sess = ort.InferenceSession(
-            onnx_path,
-            providers=(
-                ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
-            ),
-        )
+        if "OpenVINOExecutionProvider" in ort.get_available_providers():
+            # requires provider options for gpu support
+            # fp16 causes nonsense outputs
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(["OpenVINOExecutionProvider"]),
+                provider_options=[{'device_type' : "GPU_FP32"}],
+            )
+        else:
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(
+                    ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
+                    ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
+                    ["CPUExecutionProvider"]
+                ),
+            )
     else:
         from tensorflow.keras.models import load_model
 
diff --git a/library/ipex/attention.py b/library/ipex/attention.py
index 8253c5b17..d989ad53d 100644
--- a/library/ipex/attention.py
+++ b/library/ipex/attention.py
@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
                     mat2[start_idx:end_idx],
                     out=out
                 )
+        torch.xpu.synchronize(input.device)
     else:
         return original_torch_bmm(input, mat2, out=out)
-    torch.xpu.synchronize(input.device)
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
     if query.device.type != "xpu":
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
 
     # Slice SDPA
@@ -153,7 +153,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
                                 key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
-                                dropout_p=dropout_p, is_causal=is_causal
+                                dropout_p=dropout_p, is_causal=is_causal, **kwargs
                             )
                     else:
                         hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
@@ -161,7 +161,7 @@
                             key[start_idx:end_idx, start_idx_2:end_idx_2],
                             value[start_idx:end_idx, start_idx_2:end_idx_2],
                             attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
-                            dropout_p=dropout_p, is_causal=is_causal
+                            dropout_p=dropout_p, is_causal=is_causal, **kwargs
                         )
             else:
                 hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
@@ -169,9 +169,9 @@
                     key[start_idx:end_idx],
                     value[start_idx:end_idx],
                     attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
-                    dropout_p=dropout_p, is_causal=is_causal
+                    dropout_p=dropout_p, is_causal=is_causal, **kwargs
                 )
+        torch.xpu.synchronize(query.device)
     else:
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
-    torch.xpu.synchronize(query.device)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     return hidden_states
diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py
index b1b9ccf0e..65089f39e 100644
--- a/library/ipex/hijacks.py
+++ b/library/ipex/hijacks.py
@@ -12,7 +12,7 @@
 class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
     def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
         if isinstance(device_ids, list) and len(device_ids) > 1:
-            logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
+            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
         return module.to("xpu")
 
 def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
 original_interpolate = torch.nn.functional.interpolate
 @wraps(torch.nn.functional.interpolate)
 def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
-    if antialias or align_corners is not None:
+    if antialias or align_corners is not None or mode == 'bicubic':
         return_device = tensor.device
         return_dtype = tensor.dtype
         return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -216,7 +216,9 @@ def torch_empty(*args, device=None, **kwargs):
 
 original_torch_randn = torch.randn
 @wraps(torch.randn)
-def torch_randn(*args, device=None, **kwargs):
+def torch_randn(*args, device=None, dtype=None, **kwargs):
+    if dtype == bytes:
+        dtype = None
     if check_device(device):
         return original_torch_randn(*args, device=return_xpu(device), **kwargs)
     else:
@@ -256,11 +258,11 @@ def torch_Generator(device=None):
 
 original_torch_load = torch.load
 @wraps(torch.load)
-def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
+def torch_load(f, map_location=None, *args, **kwargs):
     if check_device(map_location):
-        return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
+        return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
     else:
-        return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
+        return original_torch_load(f, map_location=map_location, *args, **kwargs)
 
 
 # Hijack Functions: