Skip to content

Commit

Permalink
Retain alpha in pil_resize
Browse files Browse the repository at this point in the history
Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image is not larger than the bucket resolution in both width and height (i.e. whenever the resize is not handled by `cv2.resize`).

This codepath is entered on the last line, here:
```
def trim_and_resize_if_required(
    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
    image_height, image_width = image.shape[0:2]
    original_size = (image_width, image_height)  # size before resize

    if image_width != resized_size[0] or image_height != resized_size[1]:
        # リサイズする
        if image_width > resized_size[0] and image_height > resized_size[1]:
            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # INTER_AREAでやりたいのでcv2でリサイズ
        else:
            image = pil_resize(image, resized_size)
```
  • Loading branch information
emcmanus authored Sep 19, 2024
1 parent b844c70 commit 3957372
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions library/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,26 @@ def _convert_float8(byte_tensor, dtype_str, shape):
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")

def pil_resize(image, size, interpolation=Image.LANCZOS):
    """Resize a cv2-style image array with Pillow, keeping any alpha channel.

    Args:
        image: Image array in OpenCV channel order: ``(H, W, 3)`` BGR,
            ``(H, W, 4)`` BGRA, or a 2-D single-channel array. The dtype
            must be one ``Image.fromarray`` accepts (presumably uint8 --
            confirm against callers).
        size: Target ``(width, height)``, passed straight to
            ``PIL.Image.Image.resize``.
        interpolation: Pillow resampling filter. Defaults to Lanczos.

    Returns:
        The resized image as a NumPy array with the same channel layout
        (BGR, BGRA, or single-channel) as the input.
    """
    if image.ndim == 2:
        # Single-channel input has no channel order to swap, and
        # cv2.cvtColor(..., COLOR_BGR2RGB) would reject it.
        return np.array(Image.fromarray(image).resize(size, interpolation))

    # A 4-channel input carries alpha; use the BGRA<->RGBA codes so the
    # alpha plane survives the round trip through Pillow.
    has_alpha = image.shape[2] == 4
    if has_alpha:
        to_pil_code, to_cv2_code = cv2.COLOR_BGRA2RGBA, cv2.COLOR_RGBA2BGRA
    else:
        to_pil_code, to_cv2_code = cv2.COLOR_BGR2RGB, cv2.COLOR_RGB2BGR

    # cv2 order -> Pillow order, resize, then back to cv2 order.
    pil_image = Image.fromarray(cv2.cvtColor(image, to_pil_code))
    resized_pil = pil_image.resize(size, interpolation)
    return cv2.cvtColor(np.array(resized_pil), to_cv2_code)

Expand Down

0 comments on commit 3957372

Please sign in to comment.