Skip to content

Commit

Permalink
Allow targeting labels in service calls
Browse files Browse the repository at this point in the history
  • Loading branch information
frenck committed Jul 7, 2022
1 parent 12cab88 commit 5959029
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
3 changes: 3 additions & 0 deletions homeassistant/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,9 @@ class Platform(StrEnum):
# Contains one string, the device ID
ATTR_DEVICE_ID: Final = "device_id"

# Contains one string or a list of strings, each being an label id
ATTR_LABEL_ID: Final = "label_id"

# String with a friendly name for the entity
ATTR_FRIENDLY_NAME: Final = "friendly_name"

Expand Down
7 changes: 7 additions & 0 deletions homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_LABEL_ID,
CONF_ABOVE,
CONF_ALIAS,
CONF_ATTRIBUTE,
Expand Down Expand Up @@ -1032,6 +1033,9 @@ def expand_condition_shorthand(value: Any | None) -> Any:
vol.Optional(ATTR_AREA_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
vol.Optional(ATTR_LABEL_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
}

TARGET_SERVICE_FIELDS = {
Expand All @@ -1049,6 +1053,9 @@ def expand_condition_shorthand(value: Any | None) -> Any:
vol.Optional(ATTR_AREA_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
vol.Optional(ATTR_LABEL_ID): vol.Any(
ENTITY_MATCH_NONE, vol.All(ensure_list, [vol.Any(dynamic_template, str)])
),
}


Expand Down
41 changes: 37 additions & 4 deletions homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ATTR_AREA_ID,
ATTR_DEVICE_ID,
ATTR_ENTITY_ID,
ATTR_LABEL_ID,
CONF_ENTITY_ID,
CONF_SERVICE,
CONF_SERVICE_DATA,
Expand Down Expand Up @@ -46,6 +47,7 @@
config_validation as cv,
device_registry,
entity_registry,
label_registry,
template,
)
from .typing import ConfigType, TemplateVarsType
Expand Down Expand Up @@ -80,6 +82,7 @@ def __init__(self, service_call: ServiceCall) -> None:
entity_ids: str | list | None = service_call.data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call.data.get(ATTR_DEVICE_ID)
area_ids: str | list | None = service_call.data.get(ATTR_AREA_ID)
label_ids: str | list | None = service_call.data.get(ATTR_LABEL_ID)

self.entity_ids = (
set(cv.ensure_list(entity_ids)) if _has_match(entity_ids) else set()
Expand All @@ -88,11 +91,16 @@ def __init__(self, service_call: ServiceCall) -> None:
set(cv.ensure_list(device_ids)) if _has_match(device_ids) else set()
)
self.area_ids = set(cv.ensure_list(area_ids)) if _has_match(area_ids) else set()
self.label_ids = (
set(cv.ensure_list(label_ids)) if _has_match(label_ids) else set()
)

@property
def has_any_selector(self) -> bool:
"""Determine if any selectors are present."""
return bool(self.entity_ids or self.device_ids or self.area_ids)
return bool(
self.entity_ids or self.device_ids or self.area_ids or self.label_ids
)


@dataclasses.dataclass
Expand All @@ -102,13 +110,14 @@ class SelectedEntities:
# Entities that were explicitly mentioned.
referenced: set[str] = dataclasses.field(default_factory=set)

# Entities that were referenced via device/area ID.
# Entities that were referenced via device/area/label ID.
# Should not trigger a warning when they don't exist.
indirectly_referenced: set[str] = dataclasses.field(default_factory=set)

# Referenced items that could not be found.
missing_devices: set[str] = dataclasses.field(default_factory=set)
missing_areas: set[str] = dataclasses.field(default_factory=set)
missing_labels: set[str] = dataclasses.field(default_factory=set)

# Referenced devices
referenced_devices: set[str] = dataclasses.field(default_factory=set)
Expand All @@ -120,6 +129,7 @@ def log_missing(self, missing_entities: set[str]) -> None:
("areas", self.missing_areas),
("devices", self.missing_devices),
("entities", missing_entities),
("labels", self.missing_labels),
):
if items:
parts.append(f"{label} {', '.join(sorted(items))}")
Expand Down Expand Up @@ -350,12 +360,13 @@ def async_extract_referenced_entity_ids(

selected.referenced.update(entity_ids)

if not selector.device_ids and not selector.area_ids:
if not selector.device_ids and not selector.area_ids and not selector.label_ids:
return selected

ent_reg = entity_registry.async_get(hass)
dev_reg = device_registry.async_get(hass)
area_reg = area_registry.async_get(hass)
label_reg = label_registry.async_get(hass)

for device_id in selector.device_ids:
if device_id not in dev_reg.devices:
Expand All @@ -365,13 +376,27 @@ def async_extract_referenced_entity_ids(
if area_id not in area_reg.areas:
selected.missing_areas.add(area_id)

for label_id in selector.label_ids:
if label_id not in label_reg.labels:
selected.missing_labels.add(label_id)

# Find devices for targeted areas
selected.referenced_devices.update(selector.device_ids)
for device_entry in dev_reg.devices.values():
if device_entry.area_id in selector.area_ids:
selected.referenced_devices.add(device_entry.id)

if not selector.area_ids and not selected.referenced_devices:
# Find devices for targeted areas
selected.referenced_devices.update(selector.device_ids)
for device_entry in dev_reg.devices.values():
if device_entry.labels.intersection(selector.label_ids):
selected.referenced_devices.add(device_entry.id)

if (
not selector.area_ids
and not selector.label_ids
and not selected.referenced_devices
):
return selected

for ent_entry in ent_reg.entities.values():
Expand All @@ -390,6 +415,14 @@ def async_extract_referenced_entity_ids(
)
# The entity's device matches a targeted device
or ent_entry.device_id in selector.device_ids
# The entity's label matches a targeted label
or ent_entry.labels.intersection(selector.label_ids)
# The entity's device matches a device referenced by an label and
# the entity has no explicitly set labels
or (
not ent_entry.labels
and ent_entry.device_id in selected.referenced_devices
)
):
selected.indirectly_referenced.add(ent_entry.entity_id)

Expand Down

0 comments on commit 5959029

Please sign in to comment.