Skip to content

Commit

Permalink
Add label support to entity registry
Browse files Browse the repository at this point in the history
  • Loading branch information
frenck committed Sep 17, 2022
1 parent 883d36c commit b2f0cdd
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 17 deletions.
51 changes: 40 additions & 11 deletions homeassistant/helpers/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
_LOGGER = logging.getLogger(__name__)

STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 8
STORAGE_VERSION_MINOR = 9
STORAGE_KEY = "core.entity_registry"

# Attributes relevant to describing entity
Expand Down Expand Up @@ -113,6 +113,7 @@ class RegistryEntry:
icon: str | None = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex)
has_entity_name: bool = attr.ib(default=False)
labels: set[str] = attr.ib(factory=set)
name: str | None = attr.ib(default=None)
options: Mapping[str, Mapping[str, Any]] = attr.ib(
default=None, converter=attr.converters.default_if_none(factory=dict) # type: ignore[misc]
Expand Down Expand Up @@ -231,6 +232,11 @@ async def _async_migrate_func(
continue
entity["device_class"] = None

if old_major_version == 1 and old_minor_version < 9:
# Version 1.9 adds labels
for entity in data["entities"]:
entity["labels"] = []

if old_major_version > 1:
raise NotImplementedError
return data
Expand Down Expand Up @@ -563,19 +569,20 @@ def _async_update_entity(
device_id: str | None | UndefinedType = UNDEFINED,
disabled_by: RegistryEntryDisabler | None | UndefinedType = UNDEFINED,
entity_category: EntityCategory | None | UndefinedType = UNDEFINED,
has_entity_name: bool | UndefinedType = UNDEFINED,
hidden_by: RegistryEntryHider | None | UndefinedType = UNDEFINED,
icon: str | None | UndefinedType = UNDEFINED,
has_entity_name: bool | UndefinedType = UNDEFINED,
labels: set[str] | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
new_entity_id: str | UndefinedType = UNDEFINED,
new_unique_id: str | UndefinedType = UNDEFINED,
options: Mapping[str, Mapping[str, Any]] | UndefinedType = UNDEFINED,
original_device_class: str | None | UndefinedType = UNDEFINED,
original_icon: str | None | UndefinedType = UNDEFINED,
original_name: str | None | UndefinedType = UNDEFINED,
platform: str | None | UndefinedType = UNDEFINED,
supported_features: int | UndefinedType = UNDEFINED,
unit_of_measurement: str | None | UndefinedType = UNDEFINED,
platform: str | None | UndefinedType = UNDEFINED,
options: Mapping[str, Mapping[str, Any]] | UndefinedType = UNDEFINED,
) -> RegistryEntry:
"""Private facing update properties method."""
old = self.entities[entity_id]
Expand Down Expand Up @@ -613,17 +620,18 @@ def _async_update_entity(
("device_id", device_id),
("disabled_by", disabled_by),
("entity_category", entity_category),
("has_entity_name", has_entity_name),
("hidden_by", hidden_by),
("icon", icon),
("has_entity_name", has_entity_name),
("labels", labels),
("name", name),
("options", options),
("original_device_class", original_device_class),
("original_icon", original_icon),
("original_name", original_name),
("platform", platform),
("supported_features", supported_features),
("unit_of_measurement", unit_of_measurement),
("platform", platform),
("options", options),
):
if value is not UNDEFINED and value != getattr(old, attr_name):
new_values[attr_name] = value
Expand Down Expand Up @@ -687,9 +695,10 @@ def async_update_entity(
device_id: str | None | UndefinedType = UNDEFINED,
disabled_by: RegistryEntryDisabler | None | UndefinedType = UNDEFINED,
entity_category: EntityCategory | None | UndefinedType = UNDEFINED,
has_entity_name: bool | UndefinedType = UNDEFINED,
hidden_by: RegistryEntryHider | None | UndefinedType = UNDEFINED,
icon: str | None | UndefinedType = UNDEFINED,
has_entity_name: bool | UndefinedType = UNDEFINED,
labels: set[str] | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
new_entity_id: str | UndefinedType = UNDEFINED,
new_unique_id: str | UndefinedType = UNDEFINED,
Expand All @@ -709,9 +718,10 @@ def async_update_entity(
device_id=device_id,
disabled_by=disabled_by,
entity_category=entity_category,
has_entity_name=has_entity_name,
hidden_by=hidden_by,
icon=icon,
has_entity_name=has_entity_name,
labels=labels,
name=name,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
Expand Down Expand Up @@ -795,12 +805,13 @@ async def async_load(self) -> None:
if entity["entity_category"]
else None,
entity_id=entity["entity_id"],
has_entity_name=entity["has_entity_name"],
hidden_by=RegistryEntryHider(entity["hidden_by"])
if entity["hidden_by"]
else None,
icon=entity["icon"],
id=entity["id"],
has_entity_name=entity["has_entity_name"],
labels=set(entity["labels"]),
name=entity["name"],
options=entity["options"],
original_device_class=entity["original_device_class"],
Expand Down Expand Up @@ -834,10 +845,11 @@ def _data_to_save(self) -> dict[str, Any]:
"disabled_by": entry.disabled_by,
"entity_category": entry.entity_category,
"entity_id": entry.entity_id,
"has_entity_name": entry.has_entity_name,
"hidden_by": entry.hidden_by,
"icon": entry.icon,
"id": entry.id,
"has_entity_name": entry.has_entity_name,
"labels": list(entry.labels),
"name": entry.name,
"options": entry.options,
"original_device_class": entry.original_device_class,
Expand Down Expand Up @@ -870,6 +882,15 @@ def async_clear_area_id(self, area_id: str) -> None:
if area_id == entry.area_id:
self.async_update_entity(entity_id, area_id=None)

@callback
def async_clear_label_id(self, label_id: str) -> None:
"""Clear label from registry entries."""
for entity_id, entry in self.entities.items():
if label_id in entry.labels:
labels = entry.labels.copy()
labels.remove(label_id)
self.async_update_entity(entity_id, labels=labels)


@callback
def async_get(hass: HomeAssistant) -> EntityRegistry:
Expand Down Expand Up @@ -917,6 +938,14 @@ def async_entries_for_area(
return [entry for entry in registry.entities.values() if entry.area_id == area_id]


@callback
def async_entries_for_label(
registry: EntityRegistry, label_id: str
) -> list[RegistryEntry]:
"""Return entries that match an label."""
return [entry for entry in registry.entities.values() if label_id in entry.labels]


@callback
def async_entries_for_config_entry(
registry: EntityRegistry, config_entry_id: str
Expand Down
8 changes: 3 additions & 5 deletions homeassistant/helpers/label_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,9 @@ def async_delete(self, label_id: str) -> None:
"""Delete label."""
label = self.labels[label_id]

# Clean up all references (TODO)
device_registry = dr.async_get(self.hass)
device_registry.async_clear_label_id(label_id)
_ = er.async_get(self.hass)
# entity_registry.async_clear_label_id(label_id)
# Clean up all references
dr.async_get(self.hass).async_clear_label_id(label_id)
er.async_get(self.hass).async_clear_label_id(label_id)

del self.labels[label_id]
del self._normalized_name_label_idx[label.normalized_name]
Expand Down
67 changes: 67 additions & 0 deletions tests/helpers/test_entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ async def test_loading_saving_data(hass, registry):
registry.async_update_entity_options(
orig_entry2.entity_id, "light", {"minimum_brightness": 20}
)
registry.async_update_entity(orig_entry2.entity_id, labels={"label1", "label2"})
orig_entry2 = registry.async_get(orig_entry2.entity_id)

assert len(registry.entities) == 2
Expand Down Expand Up @@ -281,6 +282,7 @@ async def test_loading_saving_data(hass, registry):
assert new_entry2.icon == "hass:user-icon"
assert new_entry2.hidden_by == er.RegistryEntryHider.INTEGRATION
assert new_entry2.has_entity_name is True
assert new_entry2.labels == {"label1", "label2"}
assert new_entry2.name == "User Name"
assert new_entry2.options == {"light": {"minimum_brightness": 20}}
assert new_entry2.original_device_class == "mock-device-class"
Expand Down Expand Up @@ -1371,3 +1373,68 @@ def test_migrate_entity_to_new_platform(hass, registry):
new_unique_id=new_unique_id,
new_config_entry_id=new_config_entry.entry_id,
)


async def test_removing_labels(registry: er.EntityRegistry) -> None:
"""Make sure we can clear labels."""
entry = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="5678",
)
entry = registry.async_update_entity(entry.entity_id, labels={"label1", "label2"})

registry.async_clear_label_id("label1")
entry_cleared_label1 = registry.async_get(entry.entity_id)

registry.async_clear_label_id("label2")
entry_cleared_label2 = registry.async_get(entry.entity_id)

assert entry_cleared_label1
assert entry_cleared_label2
assert entry != entry_cleared_label1
assert entry != entry_cleared_label2
assert entry_cleared_label1 != entry_cleared_label2
assert entry.labels == {"label1", "label2"}
assert entry_cleared_label1.labels == {"label2"}
assert not entry_cleared_label2.labels


async def test_entries_for_label(registry: er.EntityRegistry) -> None:
"""Test getting entity entries by label."""
registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="000",
)
entry = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="123",
)
label_1 = registry.async_update_entity(entry.entity_id, labels={"label1"})
entry = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="456",
)
label_2 = registry.async_update_entity(entry.entity_id, labels={"label2"})
entry = registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="789",
)
label_1_and_2 = registry.async_update_entity(
entry.entity_id, labels={"label1", "label2"}
)

entries = er.async_entries_for_label(registry, "label1")
assert len(entries) == 2
assert entries == [label_1, label_1_and_2]

entries = er.async_entries_for_label(registry, "label2")
assert len(entries) == 2
assert entries == [label_2, label_1_and_2]

assert not er.async_entries_for_label(registry, "unknown")
assert not er.async_entries_for_label(registry, "")
55 changes: 54 additions & 1 deletion tests/helpers/test_label_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import pytest

from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, label_registry
from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
label_registry,
)
from homeassistant.helpers.label_registry import (
EVENT_LABEL_REGISTRY_UPDATED,
STORAGE_KEY,
Expand Down Expand Up @@ -387,3 +391,52 @@ async def test_labels_removed_from_devices(hass: HomeAssistant) -> None:
assert len(entries) == 0
entries = dr.async_entries_for_label(device_registry, label2.label_id)
assert len(entries) == 0


async def test_labels_removed_from_entities(hass: HomeAssistant) -> None:
"""Tests if label gets removed from entity when the label is removed."""
registry = label_registry.async_get(hass)
label1 = registry.async_create("label1")
label2 = registry.async_create("label2")
assert len(registry.labels) == 2

entity_registry = er.async_get(hass)
entry = entity_registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="123",
)
entity_registry.async_update_entity(entry.entity_id, labels={label1.label_id})
entry = entity_registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="456",
)
entity_registry.async_update_entity(entry.entity_id, labels={label2.label_id})
entry = entity_registry.async_get_or_create(
domain="light",
platform="hue",
unique_id="789",
)
entity_registry.async_update_entity(
entry.entity_id, labels={label1.label_id, label2.label_id}
)

entries = er.async_entries_for_label(entity_registry, label1.label_id)
assert len(entries) == 2
entries = er.async_entries_for_label(entity_registry, label2.label_id)
assert len(entries) == 2

registry.async_delete(label1.label_id)

entries = er.async_entries_for_label(entity_registry, label1.label_id)
assert len(entries) == 0
entries = er.async_entries_for_label(entity_registry, label2.label_id)
assert len(entries) == 2

registry.async_delete(label2.label_id)

entries = er.async_entries_for_label(entity_registry, label1.label_id)
assert len(entries) == 0
entries = er.async_entries_for_label(entity_registry, label2.label_id)
assert len(entries) == 0

0 comments on commit b2f0cdd

Please sign in to comment.