diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index 15b8283..9410b32 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -222,14 +222,15 @@ def __call__( input_image = input_image.convert("RGB") # convert to torch tensor [H, W, rgb] -> [rgb, H, W] rgb = pil_to_tensor(input_image) + rgb = rgb.unsqueeze(0) # [1, rgb, H, W] elif isinstance(input_image, torch.Tensor): - rgb = input_image.squeeze() + rgb = input_image else: raise TypeError(f"Unknown input type: {type(input_image) = }") input_size = rgb.shape assert ( - 3 == rgb.dim() and 3 == input_size[0] - ), f"Wrong input shape {input_size}, expected [rgb, H, W]" + 4 == rgb.dim() and 3 == input_size[-3] + ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]" # Resize image if processing_res > 0: @@ -246,7 +247,7 @@ def __call__( # ----------------- Predicting depth ----------------- # Batch repeated input image - duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) + duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1) single_rgb_dataset = TensorDataset(duplicated_rgb) if batch_size > 0: _bs = batch_size @@ -287,6 +288,7 @@ def __call__( depth_preds, scale_invariant=self.scale_invariant, shift_invariant=self.shift_invariant, + max_res=50, **(ensemble_kwargs or {}), ) else: @@ -297,7 +299,7 @@ def __call__( if match_input_res: depth_pred = resize( depth_pred, - input_size[1:], + input_size[-2:], interpolation=resample_method, antialias=True, ) diff --git a/marigold/util/image_util.py b/marigold/util/image_util.py index 9924bab..82078fe 100644 --- a/marigold/util/image_util.py +++ b/marigold/util/image_util.py @@ -86,7 +86,7 @@ def resize_max_res( Args: img (`torch.Tensor`): - Image tensor to be resized. + Image tensor to be resized. Expected shape: [B, C, H, W] max_edge_resolution (`int`): Maximum edge length (pixel). resample_method (`PIL.Image.Resampling`): @@ -95,8 +95,9 @@ def resize_max_res( Returns: `torch.Tensor`: Resized image. """ - assert 3 == img.dim() - _, original_height, original_width = img.shape + assert 4 == img.dim(), f"Invalid input shape {img.shape}" + + original_height, original_width = img.shape[-2:] downscale_factor = min( max_edge_resolution / original_width, max_edge_resolution / original_height )