From 3957372ded6fda20553acaf169993a422b829bdc Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:30:03 -0700 Subject: [PATCH] Retain alpha in `pil_resize` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image width does not exceed the bucket. This codepath is entered on the last line, here: ``` def trim_and_resize_if_required( random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする if image_width > resized_size[0] and image_height > resized_size[1]: image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ else: image = pil_resize(image, resized_size) ``` --- library/utils.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index a0bb19650..2171c7190 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,13 +305,26 @@ def _convert_float8(byte_tensor, dtype_str, shape): raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + # Check if the image has an alpha channel + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False - # use Pillow resize + if has_alpha: + # Convert BGRA to RGBA + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + # Convert BGR to RGB + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Resize the image resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + # Convert RGBA to BGRA + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + # Convert RGB to BGR + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2