diff --git a/homeassistant/components/threshold/binary_sensor.py b/homeassistant/components/threshold/binary_sensor.py index 6382c79b9ce87..96ccafab6d6be 100644 --- a/homeassistant/components/threshold/binary_sensor.py +++ b/homeassistant/components/threshold/binary_sensor.py @@ -1,6 +1,7 @@ """Support for monitoring if a sensor value is below/above a threshold.""" from __future__ import annotations +from collections.abc import Callable, Mapping import logging from typing import Any @@ -21,7 +22,7 @@ STATE_UNAVAILABLE, STATE_UNKNOWN, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import ( config_validation as cv, device_registry as dr, @@ -113,7 +114,6 @@ async def async_setup_entry( async_add_entities( [ ThresholdSensor( - hass, entity_id, name, lower, @@ -147,7 +147,7 @@ async def async_setup_platform( async_add_entities( [ ThresholdSensor( - hass, entity_id, name, lower, upper, hysteresis, device_class, None + entity_id, name, lower, upper, hysteresis, device_class, None ) ], ) @@ -169,7 +169,6 @@ class ThresholdSensor(BinarySensorEntity): def __init__( self, - hass: HomeAssistant, entity_id: str, name: str, lower: float | None, @@ -180,6 +179,7 @@ def __init__( device_info: DeviceInfo | None = None, ) -> None: """Initialize the Threshold sensor.""" + self._preview_callback: Callable[[str, Mapping[str, Any]], None] | None = None self._attr_unique_id = unique_id self._attr_device_info = device_info self._entity_id = entity_id @@ -195,9 +195,17 @@ def __init__( self._state: bool | None = None self.sensor_value: float | None = None + async def async_added_to_hass(self) -> None: + """Run when entity about to be added to hass.""" + self._async_setup_sensor() + + @callback + def _async_setup_sensor(self) -> None: + """Set up the sensor and start tracking state changes.""" + def _update_sensor_state() -> None: """Handle sensor state changes.""" - if (new_state := hass.states.get(self._entity_id)) is None: + if (new_state := self.hass.states.get(self._entity_id)) is None: return try: @@ -212,17 +220,26 @@ def _update_sensor_state() -> None: self._update_state() + if self._preview_callback: + calculated_state = self._async_calculate_state() + self._preview_callback( + calculated_state.state, calculated_state.attributes + ) + @callback def async_threshold_sensor_state_listener( event: EventType[EventStateChangedData], ) -> None: """Handle sensor state changes.""" _update_sensor_state() - self.async_write_ha_state() + + # only write state to the state machine if we are not in preview mode + if not self._preview_callback: + self.async_write_ha_state() self.async_on_remove( async_track_state_change_event( - hass, [entity_id], async_threshold_sensor_state_listener + self.hass, [self._entity_id], async_threshold_sensor_state_listener ) ) _update_sensor_state() @@ -262,6 +279,14 @@ def above(sensor_value: float, threshold: float) -> bool: self._state = None return + # guard against the case where the thresholds are not set + if not hasattr(self, "_threshold_lower") and not hasattr( + self, "_threshold_upper" + ): + self._state_position = POSITION_UNKNOWN + self._state = None + return + if self.threshold_type == TYPE_LOWER: if self._state is None: self._state = False @@ -307,3 +332,22 @@ def above(sensor_value: float, threshold: float) -> bool: self._state_position = POSITION_IN_RANGE self._state = True return + + @callback + def async_start_preview( + self, + preview_callback: Callable[[str, Mapping[str, Any]], None], + ) -> CALLBACK_TYPE: + """Render a preview.""" + # abort early if there is no entity_id + # as without we can't track changes + if not self._entity_id: + self._attr_available = False + calculated_state = self._async_calculate_state() + preview_callback(calculated_state.state, calculated_state.attributes) + return self._call_on_remove_callbacks + + self._preview_callback = preview_callback + + self._async_setup_sensor() + return self._call_on_remove_callbacks diff --git a/homeassistant/components/threshold/config_flow.py b/homeassistant/components/threshold/config_flow.py index 31d51fee3f37d..4c0d255c0bf86 100644 --- a/homeassistant/components/threshold/config_flow.py +++ b/homeassistant/components/threshold/config_flow.py @@ -6,8 +6,11 @@ import voluptuous as vol +from homeassistant.components import websocket_api from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.const import CONF_ENTITY_ID, CONF_NAME +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import selector from homeassistant.helpers.schema_config_entry_flow import ( SchemaCommonFlowHandler, @@ -16,6 +19,7 @@ SchemaFlowFormStep, ) +from .binary_sensor import ThresholdSensor from .const import CONF_HYSTERESIS, CONF_LOWER, CONF_UPPER, DEFAULT_HYSTERESIS, DOMAIN @@ -60,11 +64,15 @@ async def _validate_mode( ).extend(OPTIONS_SCHEMA.schema) CONFIG_FLOW = { - "user": SchemaFlowFormStep(CONFIG_SCHEMA, validate_user_input=_validate_mode) + "user": SchemaFlowFormStep( + CONFIG_SCHEMA, preview="threshold", validate_user_input=_validate_mode + ) } OPTIONS_FLOW = { - "init": SchemaFlowFormStep(OPTIONS_SCHEMA, validate_user_input=_validate_mode) + "init": SchemaFlowFormStep( + OPTIONS_SCHEMA, preview="threshold", validate_user_input=_validate_mode + ) } @@ -78,3 +86,61 @@ def async_config_entry_title(self, options: Mapping[str, Any]) -> str: """Return config entry title.""" name: str = options[CONF_NAME] return name + + @staticmethod + async def async_setup_preview(hass: HomeAssistant) -> None: + """Set up preview WS API.""" + websocket_api.async_register_command(hass, ws_start_preview) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "threshold/start_preview", + vol.Required("flow_id"): str, + vol.Required("flow_type"): vol.Any("config_flow", "options_flow"), + vol.Required("user_input"): dict, + } +) +@callback +def ws_start_preview( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Generate a preview.""" + + if msg["flow_type"] == "config_flow": + entity_id = msg["user_input"][CONF_ENTITY_ID] + name = msg["user_input"][CONF_NAME] + else: + flow_status = hass.config_entries.options.async_get(msg["flow_id"]) + config_entry = hass.config_entries.async_get_entry(flow_status["handler"]) + if not config_entry: + raise HomeAssistantError("Config entry not found") + entity_id = config_entry.options[CONF_ENTITY_ID] + name = config_entry.options[CONF_NAME] + + @callback + def async_preview_updated(state: str, attributes: Mapping[str, Any]) -> None: + """Forward config entry state events to websocket.""" + connection.send_message( + websocket_api.event_message( + msg["id"], {"attributes": attributes, "state": state} + ) + ) + + preview_entity = ThresholdSensor( + entity_id, + name, + msg["user_input"].get(CONF_LOWER), + msg["user_input"].get(CONF_UPPER), + msg["user_input"].get(CONF_HYSTERESIS), + None, + None, + ) + preview_entity.hass = hass + + connection.send_result(msg["id"]) + connection.subscriptions[msg["id"]] = preview_entity.async_start_preview( + async_preview_updated + ) diff --git a/tests/components/threshold/snapshots/test_config_flow.ambr b/tests/components/threshold/snapshots/test_config_flow.ambr new file mode 100644 index 0000000000000..e7f97386903a5 --- /dev/null +++ b/tests/components/threshold/snapshots/test_config_flow.ambr @@ -0,0 +1,54 @@ +# serializer version: 1 +# name: test_config_flow_preview_success[missing_entity_id] + dict({ + 'attributes': dict({ + 'friendly_name': '', + }), + 'state': 'unavailable', + }) +# --- +# name: test_config_flow_preview_success[missing_upper_lower] + dict({ + 'attributes': dict({ + 'entity_id': 'sensor.test_monitored', + 'friendly_name': 'Test Sensor', + 'hysteresis': 0.0, + 'lower': None, + 'position': 'unknown', + 'sensor_value': 16.0, + 'type': 'upper', + 'upper': None, + }), + 'state': 'unknown', + }) +# --- +# name: test_config_flow_preview_success[success] + dict({ + 'attributes': dict({ + 'entity_id': 'sensor.test_monitored', + 'friendly_name': 'Test Sensor', + 'hysteresis': 0.0, + 'lower': 20.0, + 'position': 'below', + 'sensor_value': 16.0, + 'type': 'lower', + 'upper': None, + }), + 'state': 'on', + }) +# --- +# name: test_options_flow_preview + dict({ + 'attributes': dict({ + 'entity_id': 'sensor.test_monitored', + 'friendly_name': 'Test Sensor', + 'hysteresis': 0.0, + 'lower': 20.0, + 'position': 'below', + 'sensor_value': 16.0, + 'type': 'lower', + 'upper': None, + }), + 'state': 'on', + }) +# --- diff --git a/tests/components/threshold/test_config_flow.py b/tests/components/threshold/test_config_flow.py index 8229ed1b1efaa..81d7e43ae136f 100644 --- a/tests/components/threshold/test_config_flow.py +++ b/tests/components/threshold/test_config_flow.py @@ -2,13 +2,16 @@ from unittest.mock import patch import pytest +from syrupy import SnapshotAssertion from homeassistant import config_entries from homeassistant.components.threshold.const import DOMAIN +from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT, UnitOfTemperature from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType from tests.common import MockConfigEntry +from tests.typing import WebSocketGenerator async def test_config_flow(hass: HomeAssistant) -> None: @@ -160,3 +163,183 @@ async def test_options(hass: HomeAssistant) -> None: state = hass.states.get("binary_sensor.my_threshold") assert state.state == "off" assert state.attributes["type"] == "upper" + + +@pytest.mark.parametrize( + "user_input", + ( + ( + { + "name": "Test Sensor", + "entity_id": "sensor.test_monitored", + "hysteresis": 0.0, + "lower": 20.0, + } + ), + ( + { + "name": "Test Sensor", + "entity_id": "sensor.test_monitored", + "hysteresis": 0.0, + } + ), + ( + { + "name": "", + "entity_id": "", + "hysteresis": 0.0, + "lower": 20.0, + } + ), + ), + ids=("success", "missing_upper_lower", "missing_entity_id"), +) +async def test_config_flow_preview_success( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + user_input: str, + snapshot: SnapshotAssertion, +) -> None: + """Test the config flow preview.""" + client = await hass_ws_client(hass) + + # add state for the tests + hass.states.async_set( + "sensor.test_monitored", + 16, + {ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS}, + ) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "user" + assert result["errors"] is None + assert result["preview"] == "threshold" + + await client.send_json_auto_id( + { + "type": "threshold/start_preview", + "flow_id": result["flow_id"], + "flow_type": "config_flow", + "user_input": user_input, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] is None + + msg = await client.receive_json() + assert msg["event"] == snapshot + assert len(hass.states.async_all()) == 1 + + +async def test_options_flow_preview( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test the options flow preview.""" + client = await hass_ws_client(hass) + + # add state for the tests + hass.states.async_set( + "sensor.test_monitored", + 16, + {ATTR_UNIT_OF_MEASUREMENT: UnitOfTemperature.CELSIUS}, + ) + + # Setup the config entry + config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + "entity_id": "sensor.test_monitored", + "hysteresis": 0.0, + "lower": 20.0, + "name": "Test Sensor", + "upper": None, + }, + title="Test Sensor", + ) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == FlowResultType.FORM + assert result["errors"] is None + assert result["preview"] == "threshold" + + await client.send_json_auto_id( + { + "type": "threshold/start_preview", + "flow_id": result["flow_id"], + "flow_type": "options_flow", + "user_input": { + "name": "Test Sensor", + "entity_id": "sensor.test_monitored", + "hysteresis": 0.0, + "lower": 20.0, + }, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] is None + + msg = await client.receive_json() + assert msg["event"] == snapshot + assert len(hass.states.async_all()) == 2 + + +async def test_options_flow_sensor_preview_config_entry_removed( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test the option flow preview where the config entry is removed.""" + client = await hass_ws_client(hass) + + # Setup the config entry + config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + "entity_id": "sensor.test_monitored", + "hysteresis": 0.0, + "lower": 20.0, + "name": "Test Sensor", + "upper": None, + }, + title="Test Sensor", + ) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == FlowResultType.FORM + assert result["errors"] is None + assert result["preview"] == "threshold" + + await hass.config_entries.async_remove(config_entry.entry_id) + + await client.send_json_auto_id( + { + "type": "threshold/start_preview", + "flow_id": result["flow_id"], + "flow_type": "options_flow", + "user_input": { + "name": "Test Sensor", + "entity_id": "sensor.test_monitored", + "hysteresis": 0.0, + "lower": 20.0, + }, + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "home_assistant_error", + "message": "Config entry not found", + }