
Commit

Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev
kohya-ss committed Mar 30, 2024
2 parents f1f30ab + b86af67 commit ae3f625
Showing 3 changed files with 33 additions and 20 deletions.
23 changes: 17 additions & 6 deletions finetune/tag_images_by_wd14_tagger.py
@@ -142,12 +142,23 @@ def main(args):

         del model

-        ort_sess = ort.InferenceSession(
-            onnx_path,
-            providers=(
-                ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
-            ),
-        )
+        if "OpenVINOExecutionProvider" in ort.get_available_providers():
+            # requires provider options for gpu support
+            # fp16 causes nonsense outputs
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(["OpenVINOExecutionProvider"]),
+                provider_options=[{'device_type' : "GPU_FP32"}],
+            )
+        else:
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(
+                    ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
+                    ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
+                    ["CPUExecutionProvider"]
+                ),
+            )
     else:
         from tensorflow.keras.models import load_model

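For context, a minimal standalone sketch of the provider-selection pattern this hunk introduces: prefer OpenVINO (forcing FP32 on GPU, since FP16 produced nonsense outputs), then CUDA, then ROCm, and fall back to CPU. The helper name make_session and the onnx_path argument are illustrative, not part of the repository.

import onnxruntime as ort

def make_session(onnx_path: str) -> ort.InferenceSession:
    # Pick the best available execution provider, mirroring the change above.
    available = ort.get_available_providers()
    if "OpenVINOExecutionProvider" in available:
        # OpenVINO needs explicit provider options for GPU support,
        # and FP16 gave nonsense outputs, so FP32 is forced.
        return ort.InferenceSession(
            onnx_path,
            providers=["OpenVINOExecutionProvider"],
            provider_options=[{"device_type": "GPU_FP32"}],
        )
    for provider in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
        if provider in available:
            return ort.InferenceSession(onnx_path, providers=[provider])
    return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])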
16 changes: 8 additions & 8 deletions library/ipex/attention.py
@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
                     mat2[start_idx:end_idx],
                     out=out
                 )
-        torch.xpu.synchronize(input.device)
     else:
         return original_torch_bmm(input, mat2, out=out)
+    torch.xpu.synchronize(input.device)
     return hidden_states

 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
     if query.device.type != "xpu":
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
     do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())

     # Slice SDPA
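The signature change above adds **kwargs so that keyword arguments introduced in newer PyTorch releases (such as scale) are forwarded to the original function instead of being silently dropped. A minimal sketch of the same pass-through pattern against the public torch API; sdpa_passthrough is an illustrative name.

import torch
import torch.nn.functional as F

original_sdpa = F.scaled_dot_product_attention

def sdpa_passthrough(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
    # Unknown keyword arguments are forwarded unchanged to the original function.
    return original_sdpa(query, key, value, attn_mask=attn_mask,
                         dropout_p=dropout_p, is_causal=is_causal, **kwargs)

q = k = v = torch.randn(1, 8, 16, 64)
out = sdpa_passthrough(q, k, v, scale=0.125)  # `scale` is accepted by newer torch releases (2.1+)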
@@ -153,25 +153,25 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
                                 key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                 attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
-                                dropout_p=dropout_p, is_causal=is_causal
+                                dropout_p=dropout_p, is_causal=is_causal, **kwargs
                             )
                     else:
                         hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
                             query[start_idx:end_idx, start_idx_2:end_idx_2],
                             key[start_idx:end_idx, start_idx_2:end_idx_2],
                             value[start_idx:end_idx, start_idx_2:end_idx_2],
                             attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
-                            dropout_p=dropout_p, is_causal=is_causal
+                            dropout_p=dropout_p, is_causal=is_causal, **kwargs
                         )
             else:
                 hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
                     query[start_idx:end_idx],
                     key[start_idx:end_idx],
                     value[start_idx:end_idx],
                     attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
-                    dropout_p=dropout_p, is_causal=is_causal
+                    dropout_p=dropout_p, is_causal=is_causal, **kwargs
                 )
-        torch.xpu.synchronize(query.device)
     else:
-        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
+    torch.xpu.synchronize(query.device)
     return hidden_states
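The patched function computes attention in slices (over the batch and, when needed, deeper dimensions) into a preallocated output tensor to avoid large single allocations on XPU, then synchronizes once. A simplified, batch-only sketch of that idea, without the XPU-specific synchronize call; sliced_sdpa and split_size are illustrative names.

import torch
import torch.nn.functional as F

def sliced_sdpa(query, key, value, split_size, **kwargs):
    # Compute attention batch-slice by batch-slice into a preallocated output (simplified sketch).
    hidden_states = torch.zeros_like(query)
    for start in range(0, query.shape[0], split_size):
        end = start + split_size
        hidden_states[start:end] = F.scaled_dot_product_attention(
            query[start:end], key[start:end], value[start:end], **kwargs
        )
    return hidden_states

q = k = v = torch.randn(4, 8, 16, 64)
assert torch.allclose(sliced_sdpa(q, k, v, split_size=2),
                      F.scaled_dot_product_attention(q, k, v), atol=1e-5)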
14 changes: 8 additions & 6 deletions library/ipex/hijacks.py
@@ -12,7 +12,7 @@
 class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
     def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
         if isinstance(device_ids, list) and len(device_ids) > 1:
-            logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
+            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
         return module.to("xpu")

 def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
 original_interpolate = torch.nn.functional.interpolate
 @wraps(torch.nn.functional.interpolate)
 def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
-    if antialias or align_corners is not None:
+    if antialias or align_corners is not None or mode == 'bicubic':
         return_device = tensor.device
         return_dtype = tensor.dtype
         return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
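The condition above now also routes mode='bicubic' through the CPU fall-back path, presumably because bicubic interpolation does not behave correctly on XPU. A self-contained sketch of that fall-back pattern, assuming only the public torch API; interpolate_with_cpu_fallback is an illustrative name.

import torch
import torch.nn.functional as F

original_interpolate = F.interpolate

def interpolate_with_cpu_fallback(tensor, *, size=None, scale_factor=None,
                                  mode="nearest", align_corners=None, antialias=False):
    # Problematic modes are computed on CPU in fp32, then moved back to the
    # original device and dtype.
    if antialias or align_corners is not None or mode == "bicubic":
        return_device, return_dtype = tensor.device, tensor.dtype
        out = original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size,
                                   scale_factor=scale_factor, mode=mode,
                                   align_corners=align_corners, antialias=antialias)
        return out.to(return_device, dtype=return_dtype)
    return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
                                align_corners=align_corners, antialias=antialias)

x = torch.randn(1, 3, 32, 32)
y = interpolate_with_cpu_fallback(x, scale_factor=2.0, mode="bicubic")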
@@ -216,7 +216,9 @@ def torch_empty(*args, device=None, **kwargs):

 original_torch_randn = torch.randn
 @wraps(torch.randn)
-def torch_randn(*args, device=None, **kwargs):
+def torch_randn(*args, device=None, dtype=None, **kwargs):
+    if dtype == bytes:
+        dtype = None
     if check_device(device):
         return original_torch_randn(*args, device=return_xpu(device), **kwargs)
     else:
@@ -256,11 +258,11 @@ def torch_Generator(device=None):

 original_torch_load = torch.load
 @wraps(torch.load)
-def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
+def torch_load(f, map_location=None, *args, **kwargs):
     if check_device(map_location):
-        return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
+        return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
     else:
-        return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
+        return original_torch_load(f, map_location=map_location, *args, **kwargs)


 # Hijack Functions:
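The torch_load rewrite above stops hard-coding torch.load's keyword list (pickle_module, weights_only, mmap) and forwards everything through *args/**kwargs, so the hijack keeps working as torch adds parameters or changes their defaults. A minimal sketch of the pattern, with the repository's check_device/return_xpu helpers replaced by hypothetical stand-ins.

import functools
import torch

original_torch_load = torch.load

def _wants_xpu(map_location) -> bool:
    # Hypothetical stand-in for the repository's check_device() helper.
    return isinstance(map_location, str) and map_location.startswith(("cuda", "xpu"))

@functools.wraps(torch.load)
def patched_torch_load(f, map_location=None, *args, **kwargs):
    # Only map_location is inspected; every other argument passes through untouched.
    if _wants_xpu(map_location):
        map_location = "xpu"  # plays the role of return_xpu(map_location)
    return original_torch_load(f, map_location=map_location, *args, **kwargs)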
