Skip to content

Commit

Permalink
improve color picking strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzocerrone committed Nov 9, 2024
1 parent 55eb5eb commit 51dabac
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 25 deletions.
91 changes: 67 additions & 24 deletions src/ngio/ngff_meta/fractal_image_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from collections.abc import Collection
from difflib import SequenceMatcher
from enum import Enum
from typing import Any, TypeVar

Expand Down Expand Up @@ -49,15 +50,50 @@ class NgioColors(str, Enum):
cyan = "00FFFF"
gray = "808080"
green = "00FF00"
random = "random"

def random_pick(self) -> "NgioColors":
"""Pick a random color."""
available_colors = [color for color in NgioColors if color != "random"]
return available_colors[np.random.randint(0, len(available_colors))]
@staticmethod
def semi_random_pick(channel_name: str | None = None) -> "NgioColors":
"""Try to fuzzy match the color to the channel name.

def valid_hex_color(v: str) -> str:
- If a channel name is given will try to match the channel name to the color.
- If name has the paatern 'channel_x' cyclic rotate over a list of colors
[cyan, magenta, yellow, green]
- If no channel name is given will return a random color.
"""
available_colors = NgioColors._member_names_

if channel_name is None:
# Purely random color
color_str = available_colors[np.random.randint(0, len(available_colors))]
return NgioColors.__members__[color_str]

if channel_name.startswith("channel_"):
# Rotate over a list of colors
defaults_colors = [
NgioColors.cyan,
NgioColors.magenta,
NgioColors.yellow,
NgioColors.green,
]

try:
index = int(channel_name.split("_")[-1]) % len(defaults_colors)
return defaults_colors[index]
except ValueError:
# If the name of the channel is something like
# channel_dapi this will fail an proceed to the
# standard fuzzy match
pass

similarity = {}
for color in available_colors:
# try to match the color to the channel name
similarity[color] = SequenceMatcher(None, channel_name, color).ratio()
color_str = max(similarity, key=similarity.get)
return NgioColors.__members__[color_str]


def valid_hex_color(v: str) -> bool:
"""Validate a hexadecimal color.
Check that `color` is made of exactly six elements which are letters
Expand All @@ -70,15 +106,12 @@ def valid_hex_color(v: str) -> str:
- Tommaso Comparin <[email protected]>
"""
if len(v) != 6:
raise ValueError(f'color must have length 6 (given: "{v}")')
return False
allowed_characters = "abcdefABCDEF0123456789"
for character in v:
if character not in allowed_characters:
raise ValueError(
"color must only include characters from "
f'"{allowed_characters}" (given: "{v}")'
)
return v
return False
return True


class ChannelVisualisation(BaseWithExtraFields):
Expand All @@ -95,15 +128,15 @@ class ChannelVisualisation(BaseWithExtraFields):
active(bool): Whether the channel is active.
"""

color: str | NgioColors = NgioColors.random
color: str | NgioColors | None = Field(default=None, validate_default=True)
min: int | float = 0
max: int | float = 65535
start: int | float = 0
end: int | float = 65535
active: bool = True

@classmethod
@field_validator("color", mode="after")
@classmethod
def validate_color(cls, value: str | NgioColors) -> str:
"""Color validator.
Expand All @@ -112,18 +145,22 @@ def validate_color(cls, value: str | NgioColors) -> str:
- A color name.
- A NgioColors element.
"""
if value in NgioColors:
value = NgioColors(value)
if value == NgioColors.random:
value = value.random_pick()
if value is None:
return NgioColors.semi_random_pick().value
if isinstance(value, str) and valid_hex_color(value):
return value
elif isinstance(value, str):
value_lower = value.lower()
return NgioColors.semi_random_pick(value_lower).value
elif isinstance(value, NgioColors):
return value.value

return valid_hex_color(value)
else:
raise ValueError("Invalid color value.")

@classmethod
def lazy_init(
cls,
color: str = "random",
color: str | NgioColors | None = None,
start: int | float | None = None,
end: int | float | None = None,
data_type: Any = np.uint16,
Expand All @@ -138,7 +175,7 @@ def lazy_init(
"""
start = start if start is not None else np.iinfo(data_type).min
end = end if end is not None else np.iinfo(data_type).max
return cls(
return ChannelVisualisation(
color=color,
min=np.iinfo(data_type).min,
max=np.iinfo(data_type).max,
Expand Down Expand Up @@ -166,7 +203,7 @@ def lazy_init(
cls,
label: str,
wavelength_id: str | None = None,
color: str = "random",
color: str | NgioColors | None = None,
start: int | float | None = None,
end: int | float | None = None,
data_type: Any = np.uint16,
Expand All @@ -177,10 +214,16 @@ def lazy_init(
label(str): The label of the channel.
wavelength_id(str | None): The wavelength ID of the channel.
color(str): The color of the channel in hexadecimal format or a color name.
If None, the color will be picked based on the label.
start(int | float | None): The start value of the channel.
end(int | float | None): The end value of the channel.
data_type(Any): The data type of the channel.
"""
if color is None:
# If no color is provided, try to pick a color based on the label
# See the NgioColors.semi_random_pick method for more details.
color = label

channel_visualization = ChannelVisualisation.lazy_init(
color=color, start=start, end=end, data_type=data_type
)
Expand Down
4 changes: 3 additions & 1 deletion src/ngio/ngff_meta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def create_image_metadata(
)

if channel_visualization is None:
channel_visualization = [ChannelVisualisation() for _ in channel_labels]
channel_visualization = [
ChannelVisualisation(color=label) for label in channel_labels
]
else:
if len(channel_visualization) != len(channel_labels):
raise ValueError(
Expand Down

0 comments on commit 51dabac

Please sign in to comment.