"""Support for MQTT message handling.""" import asyncio import datetime as dt import logging from itertools import groupby from operator import attrgetter from typing import Callable, List, Optional, Union import attr from homeassistant.core import callback, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.dispatcher import dispatcher_send from homeassistant.util import dt as dt_util _LOGGER = logging.getLogger(__name__) DEFAULT_PORT = 1883 DEFAULT_KEEPALIVE = 60 PROTOCOL_311 = "3.1.1" DEFAULT_PROTOCOL = PROTOCOL_311 MQTT_CONNECTED = "blitzortung_mqtt_connected" MQTT_DISCONNECTED = "blitzortung_mqtt_disconnected" MAX_RECONNECT_WAIT = 300 # seconds def _raise_on_error(result_code: int) -> None: """Raise error if error result.""" # pylint: disable=import-outside-toplevel import paho.mqtt.client as mqtt if result_code != 0: raise HomeAssistantError( f"Error talking to MQTT: {mqtt.error_string(result_code)}" ) def _match_topic(subscription: str, topic: str) -> bool: """Test if topic matches subscription.""" # pylint: disable=import-outside-toplevel from paho.mqtt.matcher import MQTTMatcher matcher = MQTTMatcher() matcher[subscription] = True try: next(matcher.iter_match(topic)) return True except StopIteration: return False PublishPayloadType = Union[str, bytes, int, float, None] @attr.s(slots=True, frozen=True) class Message: """MQTT Message.""" topic = attr.ib(type=str) payload = attr.ib(type=PublishPayloadType) qos = attr.ib(type=int) retain = attr.ib(type=bool) subscribed_topic = attr.ib(type=str, default=None) timestamp = attr.ib(type=dt.datetime, default=None) MessageCallbackType = Callable[[Message], None] @attr.s(slots=True, frozen=True) class Subscription: """Class to hold data about an active subscription.""" topic = attr.ib(type=str) callback = attr.ib(type=MessageCallbackType) qos = attr.ib(type=int, default=0) encoding = attr.ib(type=str, default="utf-8") SubscribePayloadType = Union[str, bytes] # Only bytes 
class MQTT:
    """Home Assistant MQTT client.

    Thin wrapper around a paho-mqtt ``Client`` that keeps a list of
    subscriptions, re-subscribes them on every (re)connect and dispatches
    incoming messages to the matching subscription callbacks.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        host,
        port=DEFAULT_PORT,
        keepalive=DEFAULT_KEEPALIVE,
    ) -> None:
        """Initialize Home Assistant MQTT client.

        Args:
            hass: Home Assistant instance used for scheduling jobs.
            host: Broker host passed to paho's ``connect``.
            port: Broker port.
            keepalive: Keepalive value passed to paho's ``connect``.
        """
        # We don't import on the top because some integrations
        # should be able to optionally rely on MQTT.
        import paho.mqtt.client as mqtt  # pylint: disable=import-outside-toplevel

        self.hass = hass
        self.host = host
        self.port = port
        self.keepalive = keepalive
        self.subscriptions: List[Subscription] = []
        self.connected = False
        self._mqttc: mqtt.Client = None
        # Serializes publish/subscribe/unsubscribe calls into paho.
        self._paho_lock = asyncio.Lock()

        self.init_client()

    def init_client(self):
        """Initialize paho client.

        Creates a fresh MQTT 3.1.1 client and wires its connect,
        disconnect and message callbacks to this instance.
        """
        # We don't import on the top because some integrations
        # should be able to optionally rely on MQTT.
        import paho.mqtt.client as mqtt  # pylint: disable=import-outside-toplevel

        proto = mqtt.MQTTv311
        self._mqttc = mqtt.Client(protocol=proto)

        self._mqttc.on_connect = self._mqtt_on_connect
        self._mqttc.on_disconnect = self._mqtt_on_disconnect
        self._mqttc.on_message = self._mqtt_on_message

    async def async_publish(
        self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
    ) -> None:
        """Publish a MQTT message.

        The blocking paho ``publish`` call runs in the executor while the
        paho lock is held.
        """
        async with self._paho_lock:
            _LOGGER.debug("Transmitting message on %s: %s", topic, payload)
            await self.hass.async_add_executor_job(
                self._mqttc.publish, topic, payload, qos, retain
            )

    async def async_connect(self) -> None:
        """Connect to the host and start paho's network loop.

        Connection failures are logged, not raised. The network loop is
        started in either case.
        """
        # pylint: disable=import-outside-toplevel
        import paho.mqtt.client as mqtt

        result: Optional[int] = None
        try:
            result = await self.hass.async_add_executor_job(
                self._mqttc.connect,
                self.host,
                self.port,
                self.keepalive,
            )
        except OSError as err:
            _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)

        if result is not None and result != 0:
            _LOGGER.error(
                "Failed to connect to MQTT server: %s", mqtt.error_string(result)
            )

        # NOTE(review): started even after a failed connect, presumably so
        # paho's loop can keep retrying - confirm against paho docs.
        self._mqttc.loop_start()

    async def async_disconnect(self):
        """Stop the MQTT client."""

        def stop():
            """Disconnect from the broker and stop the network loop."""
            self._mqttc.disconnect()
            self._mqttc.loop_stop()

        await self.hass.async_add_executor_job(stop)

    async def async_subscribe(
        self,
        topic: str,
        msg_callback,
        qos: int,
        encoding: Optional[str] = None,
    ) -> Callable[[], None]:
        """Set up a subscription to a topic with the provided qos.

        This method is a coroutine.

        Args:
            topic: Topic filter to subscribe to.
            msg_callback: Called with a ``Message`` for every matching
                message received.
            qos: Requested quality of service.
            encoding: Encoding used to decode payloads; ``None`` delivers
                the raw bytes.

        Returns:
            A callable that removes this subscription again.

        Raises:
            HomeAssistantError: If ``topic`` is not a string, or if the
                broker-side subscribe reports an error.
        """
        if not isinstance(topic, str):
            raise HomeAssistantError("Topic needs to be a string!")

        subscription = Subscription(topic, msg_callback, qos, encoding)
        self.subscriptions.append(subscription)

        # Only subscribe if currently connected.
        if self.connected:
            try:
                await self._async_perform_subscription(topic, qos)
            except Exception:
                # The caller never gets the remove handle when this raises,
                # so do not leave an entry behind that would be silently
                # re-subscribed on the next connect.
                self.subscriptions.remove(subscription)
                raise

        @callback
        def async_remove() -> None:
            """Remove subscription."""
            if subscription not in self.subscriptions:
                raise HomeAssistantError("Can't remove subscription twice")
            self.subscriptions.remove(subscription)

            if any(other.topic == topic for other in self.subscriptions):
                # Other subscriptions on topic remaining - don't unsubscribe.
                return

            # Only unsubscribe if currently connected.
            if self.connected:
                self.hass.async_create_task(self._async_unsubscribe(topic))

        return async_remove

    async def _async_unsubscribe(self, topic: str) -> None:
        """Unsubscribe from a topic.

        This method is a coroutine.

        Raises:
            HomeAssistantError: If paho reports a non-zero result code.
        """
        _LOGGER.debug("Unsubscribing from %s", topic)

        async with self._paho_lock:
            # paho returns a (result, mid) tuple; only the result is used.
            result, _ = await self.hass.async_add_executor_job(
                self._mqttc.unsubscribe, topic
            )
            _raise_on_error(result)

    async def _async_perform_subscription(self, topic: str, qos: int) -> None:
        """Perform a paho-mqtt subscription.

        Raises:
            HomeAssistantError: If paho reports a non-zero result code.
        """
        _LOGGER.debug("Subscribing to %s", topic)

        async with self._paho_lock:
            # paho returns a (result, mid) tuple; only the result is used.
            result, _ = await self.hass.async_add_executor_job(
                self._mqttc.subscribe, topic, qos
            )
            _raise_on_error(result)

    def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
        """On connect callback.

        Marks the client as connected, sends the MQTT_CONNECTED signal and
        resubscribes to all topics we were subscribed to. A refused
        connection is only logged.
        """
        # pylint: disable=import-outside-toplevel
        import paho.mqtt.client as mqtt

        if result_code != mqtt.CONNACK_ACCEPTED:
            _LOGGER.error(
                "Unable to connect to the MQTT broker: %s",
                mqtt.connack_string(result_code),
            )
            return

        self.connected = True
        dispatcher_send(self.hass, MQTT_CONNECTED)
        _LOGGER.info(
            "Connected to MQTT server %s:%s (%s)",
            self.host,
            self.port,
            result_code,
        )

        # Group subscriptions to only re-subscribe once for each topic.
        # groupby() needs its input sorted by the same key.
        keyfunc = attrgetter("topic")
        for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc):
            # Re-subscribe with the highest requested qos
            max_qos = max(subscription.qos for subscription in subs)
            self.hass.add_job(self._async_perform_subscription, topic, max_qos)

    def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
        """Message received callback.

        Hands the message over to ``_mqtt_handle_message`` via
        ``hass.add_job``.
        """
        self.hass.add_job(self._mqtt_handle_message, msg)

    @callback
    def _mqtt_handle_message(self, msg) -> None:
        """Dispatch a received message to every matching subscription.

        Args:
            msg: Message object handed over by paho; its ``topic``,
                ``payload``, ``qos`` and ``retain`` attributes are read.
        """
        _LOGGER.debug(
            "Received message on %s%s: %s",
            msg.topic,
            " (retained)" if msg.retain else "",
            msg.payload,
        )
        # One timestamp shared by all Message objects built for this msg.
        timestamp = dt_util.utcnow()

        for subscription in self.subscriptions:
            if not _match_topic(subscription.topic, msg.topic):
                continue

            # The payload is passed on undecoded only when the subscription
            # has no encoding (encoding is None).
            payload: SubscribePayloadType = msg.payload
            if subscription.encoding is not None:
                try:
                    payload = msg.payload.decode(subscription.encoding)
                except (AttributeError, UnicodeDecodeError):
                    _LOGGER.warning(
                        "Can't decode payload %s on %s with encoding %s (for %s)",
                        msg.payload,
                        msg.topic,
                        subscription.encoding,
                        subscription.callback,
                    )
                    continue

            message = Message(
                msg.topic,
                payload,
                msg.qos,
                msg.retain,
                subscription.topic,
                timestamp,
            )
            result = subscription.callback(message)
            # Coroutine callbacks are scheduled as a task. A plain callback
            # has already run at this point; passing its return value to
            # async_create_task would fail and abort dispatching to the
            # remaining subscriptions.
            if asyncio.iscoroutine(result):
                self.hass.async_create_task(result)

    def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
        """Disconnected callback.

        Marks the client as disconnected and sends the MQTT_DISCONNECTED
        signal.
        """
        self.connected = False
        dispatcher_send(self.hass, MQTT_DISCONNECTED)
        _LOGGER.info(
            "Disconnected from MQTT server %s:%s (%s)",
            self.host,
            self.port,
            result_code,
        )