Make the switch keep-alive callback conditional on the entity state (#382)

This commit is contained in:
Paulo Ferreira de Castro
2024-02-16 06:23:30 +00:00
committed by GitHub
parent a440b35815
commit dce8fa2ed6
2 changed files with 11 additions and 5 deletions

View File

@@ -220,9 +220,7 @@ class UnderlyingSwitch(UnderlyingEntity):
@overrides @overrides
def startup(self): def startup(self):
super().startup() super().startup()
self._keep_alive.set_async_action( self._keep_alive.set_async_action(self._keep_alive_callback)
self.turn_on if self.is_device_active else self.turn_off
)
# @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression # @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression
async def set_hvac_mode(self, hvac_mode: HVACMode) -> bool: async def set_hvac_mode(self, hvac_mode: HVACMode) -> bool:
@@ -247,9 +245,14 @@ class UnderlyingSwitch(UnderlyingEntity):
not self.is_inversed and real_state not self.is_inversed and real_state
) )
async def _keep_alive_callback(self):
"""Keep alive: Turn on if already turned on, turn off if already turned off."""
await (self.turn_on() if self.is_device_active else self.turn_off())
# @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression # @overrides this breaks some unit tests TypeError: object MagicMock can't be used in 'await' expression
async def turn_off(self): async def turn_off(self):
"""Turn heater toggleable device off.""" """Turn heater toggleable device off."""
self._keep_alive.cancel() # Cancel early to avoid a turn_on/turn_off race condition
_LOGGER.debug("%s - Stopping underlying entity %s", self, self._entity_id) _LOGGER.debug("%s - Stopping underlying entity %s", self, self._entity_id)
command = SERVICE_TURN_OFF if not self.is_inversed else SERVICE_TURN_ON command = SERVICE_TURN_OFF if not self.is_inversed else SERVICE_TURN_ON
domain = self._entity_id.split(".")[0] domain = self._entity_id.split(".")[0]
@@ -258,7 +261,7 @@ class UnderlyingSwitch(UnderlyingEntity):
try: try:
data = {ATTR_ENTITY_ID: self._entity_id} data = {ATTR_ENTITY_ID: self._entity_id}
await self._hass.services.async_call(domain, command, data) await self._hass.services.async_call(domain, command, data)
self._keep_alive.set_async_action(self.turn_off) self._keep_alive.set_async_action(self._keep_alive_callback)
except Exception: except Exception:
self._keep_alive.cancel() self._keep_alive.cancel()
raise raise
@@ -267,6 +270,7 @@ class UnderlyingSwitch(UnderlyingEntity):
async def turn_on(self): async def turn_on(self):
"""Turn heater toggleable device on.""" """Turn heater toggleable device on."""
self._keep_alive.cancel() # Cancel early to avoid a turn_on/turn_off race condition
_LOGGER.debug("%s - Starting underlying entity %s", self, self._entity_id) _LOGGER.debug("%s - Starting underlying entity %s", self, self._entity_id)
command = SERVICE_TURN_ON if not self.is_inversed else SERVICE_TURN_OFF command = SERVICE_TURN_ON if not self.is_inversed else SERVICE_TURN_OFF
domain = self._entity_id.split(".")[0] domain = self._entity_id.split(".")[0]
@@ -274,7 +278,7 @@ class UnderlyingSwitch(UnderlyingEntity):
try: try:
data = {ATTR_ENTITY_ID: self._entity_id} data = {ATTR_ENTITY_ID: self._entity_id}
await self._hass.services.async_call(domain, command, data) await self._hass.services.async_call(domain, command, data)
self._keep_alive.set_async_action(self.turn_on) self._keep_alive.set_async_action(self._keep_alive_callback)
except Exception: except Exception:
self._keep_alive.cancel() self._keep_alive.cancel()
raise raise

View File

@@ -210,6 +210,7 @@ class TestKeepAlive:
common_mocks, common_mocks,
[call("switch", SERVICE_TURN_ON, {"entity_id": "switch.mock_switch"})], [call("switch", SERVICE_TURN_ON, {"entity_id": "switch.mock_switch"})],
) )
common_mocks.mock_is_state.return_value = True
# Call the keep-alive callback a few times (as if `async_track_time_interval` # Call the keep-alive callback a few times (as if `async_track_time_interval`
# had done it) and assert that the callback function is replaced each time. # had done it) and assert that the callback function is replaced each time.
@@ -240,6 +241,7 @@ class TestKeepAlive:
common_mocks, common_mocks,
[call("switch", SERVICE_TURN_OFF, {"entity_id": "switch.mock_switch"})], [call("switch", SERVICE_TURN_OFF, {"entity_id": "switch.mock_switch"})],
) )
common_mocks.mock_is_state.return_value = False
# Call the keep-alive callback a few times (as if `async_track_time_interval` # Call the keep-alive callback a few times (as if `async_track_time_interval`
# had done it) and assert that the callback function is replaced each time. # had done it) and assert that the callback function is replaced each time.