Skip to content

Commit

Permalink
[UPDATE] resize_max_res() takes 4-dim input
Browse files Browse the repository at this point in the history
  • Loading branch information
markkua committed May 24, 2024
1 parent d129f6c commit c91a70a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
12 changes: 7 additions & 5 deletions marigold/marigold_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,15 @@ def __call__(
input_image = input_image.convert("RGB")
# convert to torch tensor [H, W, rgb] -> [rgb, H, W]
rgb = pil_to_tensor(input_image)
rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
elif isinstance(input_image, torch.Tensor):
rgb = input_image.squeeze()
rgb = input_image
else:
raise TypeError(f"Unknown input type: {type(input_image) = }")
input_size = rgb.shape
assert (
3 == rgb.dim() and 3 == input_size[0]
), f"Wrong input shape {input_size}, expected [rgb, H, W]"
4 == rgb.dim() and 3 == input_size[-3]
), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

# Resize image
if processing_res > 0:
Expand All @@ -246,7 +247,7 @@ def __call__(

# ----------------- Predicting depth -----------------
# Batch repeated input image
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
single_rgb_dataset = TensorDataset(duplicated_rgb)
if batch_size > 0:
_bs = batch_size
Expand Down Expand Up @@ -287,6 +288,7 @@ def __call__(
depth_preds,
scale_invariant=self.scale_invariant,
shift_invariant=self.shift_invariant,
max_res=50,
**(ensemble_kwargs or {}),
)
else:
Expand All @@ -297,7 +299,7 @@ def __call__(
if match_input_res:
depth_pred = resize(
depth_pred,
input_size[1:],
input_size[-2:],
interpolation=resample_method,
antialias=True,
)
Expand Down
7 changes: 4 additions & 3 deletions marigold/util/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def resize_max_res(
Args:
img (`torch.Tensor`):
Image tensor to be resized.
Image tensor to be resized. Expected shape: [B, C, H, W]
max_edge_resolution (`int`):
Maximum edge length (pixel).
resample_method (`PIL.Image.Resampling`):
Expand All @@ -95,8 +95,9 @@ def resize_max_res(
Returns:
`torch.Tensor`: Resized image.
"""
assert 3 == img.dim()
_, original_height, original_width = img.shape
assert 4 == img.dim(), f"Invalid input shape {img.shape}"

original_height, original_width = img.shape[-2:]
downscale_factor = min(
max_edge_resolution / original_width, max_edge_resolution / original_height
)
Expand Down

0 comments on commit c91a70a

Please sign in to comment.