Make the switch keep-alive callback conditional on the entity state (#382)

This commit is contained in:
Paulo Ferreira de Castro
2024-02-16 06:23:30 +00:00
committed by GitHub
parent a440b35815
commit dce8fa2ed6
2 changed files with 11 additions and 5 deletions

View File

@@ -220,9 +220,7 @@ class UnderlyingSwitch(UnderlyingEntity):
@overrides
def startup(self):
super().startup()
self._keep_alive.set_async_action(
self.turn_on if self.is_device_active else self.turn_off
)
self._keep_alive.set_async_action(self._keep_alive_callback)
# @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression
async def set_hvac_mode(self, hvac_mode: HVACMode) -> bool:
@@ -247,9 +245,14 @@ class UnderlyingSwitch(UnderlyingEntity):
not self.is_inversed and real_state
)
async def _keep_alive_callback(self):
"""Keep alive: Turn on if already turned on, turn off if already turned off."""
await (self.turn_on() if self.is_device_active else self.turn_off())
# @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression
async def turn_off(self):
"""Turn heater toggleable device off."""
self._keep_alive.cancel() # Cancel early to avoid a turn_on/turn_off race condition
_LOGGER.debug("%s - Stopping underlying entity %s", self, self._entity_id)
command = SERVICE_TURN_OFF if not self.is_inversed else SERVICE_TURN_ON
domain = self._entity_id.split(".")[0]
@@ -258,7 +261,7 @@ class UnderlyingSwitch(UnderlyingEntity):
try:
data = {ATTR_ENTITY_ID: self._entity_id}
await self._hass.services.async_call(domain, command, data)
self._keep_alive.set_async_action(self.turn_off)
self._keep_alive.set_async_action(self._keep_alive_callback)
except Exception:
self._keep_alive.cancel()
raise
@@ -267,6 +270,7 @@ class UnderlyingSwitch(UnderlyingEntity):
async def turn_on(self):
"""Turn heater toggleable device on."""
self._keep_alive.cancel() # Cancel early to avoid a turn_on/turn_off race condition
_LOGGER.debug("%s - Starting underlying entity %s", self, self._entity_id)
command = SERVICE_TURN_ON if not self.is_inversed else SERVICE_TURN_OFF
domain = self._entity_id.split(".")[0]
@@ -274,7 +278,7 @@ class UnderlyingSwitch(UnderlyingEntity):
try:
data = {ATTR_ENTITY_ID: self._entity_id}
await self._hass.services.async_call(domain, command, data)
self._keep_alive.set_async_action(self.turn_on)
self._keep_alive.set_async_action(self._keep_alive_callback)
except Exception:
self._keep_alive.cancel()
raise

View File

@@ -210,6 +210,7 @@ class TestKeepAlive:
common_mocks,
[call("switch", SERVICE_TURN_ON, {"entity_id": "switch.mock_switch"})],
)
common_mocks.mock_is_state.return_value = True
# Call the keep-alive callback a few times (as if `async_track_time_interval`
# had done it) and assert that the callback function is replaced each time.
@@ -240,6 +241,7 @@ class TestKeepAlive:
common_mocks,
[call("switch", SERVICE_TURN_OFF, {"entity_id": "switch.mock_switch"})],
)
common_mocks.mock_is_state.return_value = False
# Call the keep-alive callback a few times (as if `async_track_time_interval`
# had done it) and assert that the callback function is replaced each time.