Skip to content

Commit

Permalink
Add label support to entity registry
Browse files Browse the repository at this point in the history
  • Loading branch information
frenck committed Apr 16, 2022
1 parent 858865e commit 6b7be04
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
32 changes: 31 additions & 1 deletion homeassistant/helpers/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
_LOGGER = logging.getLogger(__name__)

STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 6
STORAGE_VERSION_MINOR = 7
STORAGE_KEY = "core.entity_registry"

# Attributes relevant to describing entity
Expand Down Expand Up @@ -109,6 +109,7 @@ class RegistryEntry:
hidden_by: RegistryEntryHider | None = attr.ib(default=None)
icon: str | None = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex)
labels: set[str] = attr.ib(factory=set)
name: str | None = attr.ib(default=None)
options: Mapping[str, Mapping[str, Any]] = attr.ib(
default=None, converter=attr.converters.default_if_none(factory=dict) # type: ignore[misc]
Expand Down Expand Up @@ -328,6 +329,7 @@ def async_get_or_create(
config_entry: ConfigEntry | None = None,
device_id: str | None = None,
entity_category: EntityCategory | None = None,
labels: set[str] | None = None,
original_device_class: str | None = None,
original_icon: str | None = None,
original_name: str | None = None,
Expand All @@ -349,6 +351,7 @@ def async_get_or_create(
config_entry_id=config_entry_id or UNDEFINED,
device_id=device_id or UNDEFINED,
entity_category=entity_category or UNDEFINED,
labels=labels or UNDEFINED,
original_device_class=original_device_class or UNDEFINED,
original_icon=original_icon or UNDEFINED,
original_name=original_name or UNDEFINED,
Expand Down Expand Up @@ -391,6 +394,7 @@ def async_get_or_create(
entity_category=entity_category,
entity_id=entity_id,
hidden_by=hidden_by,
labels=labels or set(),
original_device_class=original_device_class,
original_icon=original_icon,
original_name=original_name,
Expand Down Expand Up @@ -497,6 +501,7 @@ def async_update_entity(
entity_category: EntityCategory | None | UndefinedType = UNDEFINED,
hidden_by: RegistryEntryHider | None | UndefinedType = UNDEFINED,
icon: str | None | UndefinedType = UNDEFINED,
labels: set[str] | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
new_entity_id: str | UndefinedType = UNDEFINED,
new_unique_id: str | UndefinedType = UNDEFINED,
Expand Down Expand Up @@ -538,6 +543,7 @@ def async_update_entity(
("entity_category", entity_category),
("hidden_by", hidden_by),
("icon", icon),
("labels", labels),
("name", name),
("original_device_class", original_device_class),
("original_icon", original_icon),
Expand Down Expand Up @@ -655,6 +661,7 @@ async def async_load(self) -> None:
hidden_by=entity["hidden_by"],
icon=entity["icon"],
id=entity["id"],
labels=set(entity["labels"]),
name=entity["name"],
options=entity["options"],
original_device_class=entity["original_device_class"],
Expand Down Expand Up @@ -691,6 +698,7 @@ def _data_to_save(self) -> dict[str, Any]:
"hidden_by": entry.hidden_by,
"icon": entry.icon,
"id": entry.id,
"labels": list(entry.labels),
"name": entry.name,
"options": entry.options,
"original_device_class": entry.original_device_class,
Expand Down Expand Up @@ -723,6 +731,15 @@ def async_clear_area_id(self, area_id: str) -> None:
if area_id == entry.area_id:
self.async_update_entity(entity_id, area_id=None)

@callback
def async_clear_label_id(self, label_id: str) -> None:
"""Clear label from registry entries."""
for entity_id, entry in self.entities.items():
if label_id in entry.labels:
labels = entry.labels.copy()
labels.remove(label_id)
self.async_update_entity(entity_id, labels=labels)


@callback
def async_get(hass: HomeAssistant) -> EntityRegistry:
Expand Down Expand Up @@ -767,6 +784,14 @@ def async_entries_for_area(
return [entry for entry in registry.entities.values() if entry.area_id == area_id]


@callback
def async_entries_for_label(
registry: EntityRegistry, label_id: str
) -> list[RegistryEntry]:
"""Return entries that match an label."""
return [entry for entry in registry.entities.values() if label_id in entry.labels]


@callback
def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str
Expand Down Expand Up @@ -854,6 +879,11 @@ async def _async_migrate(
for entity in data["entities"]:
entity["hidden_by"] = None

if old_major_version == 1 and old_minor_version < 7:
# Version 1.7 adds labels
for entity in data["entities"]:
entity["labels"] = []

if old_major_version > 1:
raise NotImplementedError
return data
Expand Down
66 changes: 66 additions & 0 deletions tests/helpers/test_entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_get_or_create_suggested_object_id(registry):

def test_get_or_create_updates_data(registry):
"""Test that we update data in get_or_create."""
# TODO Add label stuff
orig_config_entry = MockConfigEntry(domain="light")

orig_entry = registry.async_get_or_create(
Expand Down Expand Up @@ -178,6 +179,7 @@ def test_create_triggers_save(hass, registry):

async def test_loading_saving_data(hass, registry):
"""Test that we load/save data correctly."""
# TODO Add label stuff
mock_config = MockConfigEntry(domain="light")

orig_entry1 = registry.async_get_or_create("light", "hue", "1234")
Expand Down Expand Up @@ -383,6 +385,7 @@ async def test_removing_config_entry_id(hass, registry, update_events):

async def test_removing_area_id(registry):
"""Make sure we can clear area id."""
# TODO Duplicate for labels
entry = registry.async_get_or_create("light", "hue", "5678")

entry_w_area = registry.async_update_entity(entry.entity_id, area_id="12345A")
Expand Down Expand Up @@ -1247,3 +1250,66 @@ async def test_entity_category_str_not_allowed(hass):
reg.async_update_entity(
entity_id, entity_category=EntityCategory.DIAGNOSTIC.value
)


async def test_removing_labels(registry: er.EntityRegistry) -> None:
"""Make sure we can clear labels."""
entry = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="5678",
labels={"label1", "label2"},
)

registry.async_clear_label_id("label1")
entry_cleared_label1 = registry.async_get(entry.entity_id)

registry.async_clear_label_id("label2")
entry_cleared_label2 = registry.async_get(entry.entity_id)

assert entry_cleared_label1
assert entry_cleared_label2
assert entry != entry_cleared_label1
assert entry != entry_cleared_label2
assert entry_cleared_label1 != entry_cleared_label2
assert entry.labels == {"label1", "label2"}
assert entry_cleared_label1.labels == {"label2"}
assert not entry_cleared_label2.labels


async def test_entries_for_label(registry: er.EntityRegistry) -> None:
"""Test getting entity entries by label."""
registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="000",
)
label_1 = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="123",
labels={"label1"},
)
label_2 = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="456",
labels={"label2"},
)
label_1_and_2 = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="789",
labels={"label1", "label2"},
)

entries = er.async_entries_for_label(registry, "label1")
assert len(entries) == 2
assert entries == [label_1, label_1_and_2]

entries = er.async_entries_for_label(registry, "label2")
assert len(entries) == 2
assert entries == [label_2, label_1_and_2]

assert not er.async_entries_for_label(registry, "unknown")
assert not er.async_entries_for_label(registry, "")

0 comments on commit 6b7be04

Please sign in to comment.