Skip to content

Commit

Permalink
Add a colormap option to ShowImages
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553126544
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Aug 2, 2023
1 parent 297740f commit 6e08503
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion kauldron/summaries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class ShowImages(ImageSummary):
height: Optional[int] = None
in_vrange: Optional[tuple[float, float]] = None
convert_to_float: bool = False
cmap: str | None = None

def gather_kwargs(self, context: Any) -> dict[str, Images]:
# optimize gather_kwargs to only return num_images many images
Expand All @@ -101,7 +102,7 @@ def gather_kwargs(self, context: Any) -> dict[str, Images]:
return {"images": images}

@typechecked
def get_images(self, images: Images) -> Float["n _h _w c"]:
def get_images(self, images: Images) -> Float["n _h _w _c"]:
# flatten batch dimensions
images = einops.rearrange(images, "... h w c -> (...) h w c")
images = images[: self.num_images]
Expand All @@ -111,6 +112,15 @@ def get_images(self, images: Images) -> Float["n _h _w c"]:
images = (images - vmin) / (vmax - vmin)
# convert to float
images = media.to_type(images, np.float32)

if self.cmap is not None:
if not isinstance(images, Float["n h w 1"]):
raise ValueError(
"Colormap only supported for single channel inputs (got"
f" {images.shape})"
)
images = media.to_rgb(images[..., 0], cmap=self.cmap)

# always clip to avoid display problems in TB and Datatables
images = np.clip(images, 0.0, 1.0)
# maybe resize
Expand Down

0 comments on commit 6e08503

Please sign in to comment.