306 lines
9.8 KiB
Python
306 lines
9.8 KiB
Python
"""Support for MQTT message handling."""
|
|
import asyncio
import datetime as dt
from functools import lru_cache
from itertools import groupby
import logging
from operator import attrgetter
from typing import Awaitable, Callable, List, Optional, Union

import attr

from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.util import dt as dt_util
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Connection defaults: MQTT.__init__ uses them, and they end up as arguments
# to the paho client's connect() call.
DEFAULT_PORT = 1883
DEFAULT_KEEPALIVE = 60

# The paho client is always created with MQTT protocol version 3.1.1.
PROTOCOL_311 = "3.1.1"
DEFAULT_PROTOCOL = PROTOCOL_311

# Dispatcher signal names, sent when the broker connection comes up / goes down.
MQTT_CONNECTED = "blitzortung_mqtt_connected"
MQTT_DISCONNECTED = "blitzortung_mqtt_disconnected"

# NOTE(review): presumably a cap on reconnect delays; it is not referenced by
# the code in this module - confirm whether it is still needed.
MAX_RECONNECT_WAIT = 300  # seconds
|
|
|
|
|
|
def _raise_on_error(result_code: int) -> None:
    """Turn a non-zero paho result code into a HomeAssistantError.

    A result code of 0 means success, and nothing happens. Any other value
    raises, and the message includes paho's description of the code.
    """
    # pylint: disable=import-outside-toplevel
    import paho.mqtt.client as mqtt

    if result_code == 0:
        return

    description = mqtt.error_string(result_code)
    raise HomeAssistantError(f"Error talking to MQTT: {description}")
|
|
|
|
|
|
@lru_cache(maxsize=128)
def _matcher_for_subscription(subscription: str):
    """Return a paho ``MQTTMatcher`` whose only entry is *subscription*.

    The result is cached, so matching many incoming messages against the same
    subscription does not rebuild the matcher each time. The matcher is never
    mutated after creation; lookups only call ``iter_match``.
    """
    # pylint: disable=import-outside-toplevel
    from paho.mqtt.matcher import MQTTMatcher

    matcher = MQTTMatcher()
    matcher[subscription] = True
    return matcher


def _match_topic(subscription: str, topic: str) -> bool:
    """Test if topic matches subscription.

    :param subscription: Subscription topic filter.
    :param topic: Concrete topic of a received message.
    :returns: True if paho's matcher finds *topic* under *subscription*.
    """
    matcher = _matcher_for_subscription(subscription)
    # iter_match yields the stored value (True) for each matching entry, so the
    # first yielded item, or False if there is none, is the answer.
    return next(matcher.iter_match(topic), False)
|
|
|
|
|
|
# Payload types accepted for publishing (and carried by Message.payload);
# None is allowed as well.
PublishPayloadType = Optional[Union[str, bytes, int, float]]
|
|
|
|
|
|
@attr.s(slots=True, frozen=True)
class Message:
    """MQTT Message.

    Immutable record passed to subscription callbacks by
    ``MQTT._mqtt_handle_message``.
    """

    # Topic the message was received on.
    topic = attr.ib(type=str)
    # Decoded str when the subscription has an encoding, otherwise the raw
    # payload as received from paho.
    payload = attr.ib(type=PublishPayloadType)
    qos = attr.ib(type=int)
    retain = attr.ib(type=bool)
    # Subscription topic filter that matched; None when not supplied.
    subscribed_topic = attr.ib(type=Optional[str], default=None)
    # Receive time (set from dt_util.utcnow() by the client); None when not
    # supplied.
    timestamp = attr.ib(type=Optional[dt.datetime], default=None)
|
|
|
|
|
|
# Subscription callbacks receive a Message and must return an awaitable:
# MQTT._mqtt_handle_message passes the call's result to hass.async_create_task.
MessageCallbackType = Callable[[Message], Awaitable[None]]
|
|
|
|
|
|
@attr.s(slots=True, frozen=True)
class Subscription:
    """Class to hold data about an active subscription.

    ``topic`` is the subscription filter that incoming message topics are
    matched against with ``_match_topic``. ``callback`` is invoked with a
    ``Message`` for each match. When ``encoding`` is not None, payloads are
    decoded with it before delivery; when it is None, the payload is
    delivered as received.
    """

    topic = attr.ib(type=str)
    callback = attr.ib(type=MessageCallbackType)
    qos = attr.ib(type=int, default=0)
    # MQTT.async_subscribe passes encoding=None by default, so None must be
    # part of the declared type.
    encoding = attr.ib(type=Optional[str], default="utf-8")
|
|
|
|
|
|
# Payload type delivered to subscribers: str after decoding, or bytes when
# the subscription's encoding is None.
SubscribePayloadType = Union[str, bytes]
|
|
|
|
|
|
class MQTT:
    """Home Assistant MQTT client.

    Wraps a paho-mqtt ``Client``. Blocking paho calls run in the executor and
    are serialized by ``_paho_lock``. The paho connect and message callbacks
    hand their work to the event loop via ``hass.add_job``.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        host,
        port=DEFAULT_PORT,
        keepalive=DEFAULT_KEEPALIVE,
    ) -> None:
        """Initialize Home Assistant MQTT client.

        :param hass: Home Assistant instance used to schedule work.
        :param host: Broker host, passed to paho's ``connect``.
        :param port: Broker port, passed to paho's ``connect``.
        :param keepalive: Keepalive value, passed to paho's ``connect``.
        """
        # We don't import on the top because some integrations
        # should be able to optionally rely on MQTT.
        import paho.mqtt.client as mqtt  # pylint: disable=import-outside-toplevel

        self.hass = hass
        self.host = host
        self.port = port
        self.keepalive = keepalive
        # All active subscriptions; several may share the same topic.
        self.subscriptions: List[Subscription] = []
        # Maintained by the paho connect/disconnect callbacks.
        self.connected = False
        self._mqttc: Optional[mqtt.Client] = None
        # Serializes publish/subscribe/unsubscribe calls into the paho client.
        self._paho_lock = asyncio.Lock()

        self.init_client()

    def init_client(self):
        """Create the paho client (MQTT 3.1.1) and register its callbacks."""
        # We don't import on the top because some integrations
        # should be able to optionally rely on MQTT.
        import paho.mqtt.client as mqtt  # pylint: disable=import-outside-toplevel

        proto = mqtt.MQTTv311
        self._mqttc = mqtt.Client(protocol=proto)

        self._mqttc.on_connect = self._mqtt_on_connect
        self._mqttc.on_disconnect = self._mqtt_on_disconnect
        self._mqttc.on_message = self._mqtt_on_message

    async def async_publish(
        self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
    ) -> None:
        """Publish a MQTT message.

        :raises HomeAssistantError: If paho reports a non-zero result code.
        """
        async with self._paho_lock:
            _LOGGER.debug("Transmitting message on %s: %s", topic, payload)
            msg_info = await self.hass.async_add_executor_job(
                self._mqttc.publish, topic, payload, qos, retain
            )
            # Check paho's result code, as subscribe/unsubscribe do, instead
            # of silently dropping failed publishes.
            _raise_on_error(msg_info.rc)

    async def async_connect(self) -> None:
        """Connect to the broker and start paho's network loop.

        Connection failures (an OSError or a non-zero result code) are logged,
        not raised. The loop is started in either case.
        """
        # pylint: disable=import-outside-toplevel
        import paho.mqtt.client as mqtt

        result: Optional[int] = None
        try:
            result = await self.hass.async_add_executor_job(
                self._mqttc.connect, self.host, self.port, self.keepalive,
            )
        except OSError as err:
            _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)

        if result is not None and result != 0:
            _LOGGER.error(
                "Failed to connect to MQTT server: %s", mqtt.error_string(result)
            )

        # NOTE(review): started even after a failed attempt, presumably so
        # paho's loop can retry the connection - confirm against paho docs.
        self._mqttc.loop_start()

    async def async_disconnect(self):
        """Stop the MQTT client."""

        def stop():
            """Disconnect from the broker and stop paho's network loop."""
            self._mqttc.disconnect()
            self._mqttc.loop_stop()

        # Run the paho calls in the executor, not on the event loop.
        await self.hass.async_add_executor_job(stop)

    async def async_subscribe(
        self, topic: str, msg_callback, qos: int, encoding: Optional[str] = None,
    ) -> Callable[[], None]:
        """Set up a subscription to a topic with the provided qos.

        This method is a coroutine.

        :param topic: Subscription topic filter.
        :param msg_callback: Called with a ``Message`` for each matching
            message. Its return value is passed to ``hass.async_create_task``.
        :param qos: Requested QoS level.
        :param encoding: Payload encoding; None delivers the raw payload.
        :returns: A callback that removes this subscription.
        :raises HomeAssistantError: If ``topic`` is not a string, or if paho
            rejects the subscription while connected.
        """
        if not isinstance(topic, str):
            raise HomeAssistantError("Topic needs to be a string!")

        subscription = Subscription(topic, msg_callback, qos, encoding)
        self.subscriptions.append(subscription)

        # Only subscribe if currently connected; otherwise _mqtt_on_connect
        # subscribes to every stored topic once a connection is established.
        if self.connected:
            await self._async_perform_subscription(topic, qos)

        @callback
        def async_remove() -> None:
            """Remove subscription."""
            if subscription not in self.subscriptions:
                raise HomeAssistantError("Can't remove subscription twice")
            self.subscriptions.remove(subscription)

            if any(other.topic == topic for other in self.subscriptions):
                # Other subscriptions on topic remaining - don't unsubscribe.
                return

            # Only unsubscribe if currently connected.
            if self.connected:
                self.hass.async_create_task(self._async_unsubscribe(topic))

        return async_remove

    async def _async_unsubscribe(self, topic: str) -> None:
        """Unsubscribe from a topic.

        This method is a coroutine.

        :raises HomeAssistantError: If paho reports a non-zero result code.
        """
        _LOGGER.debug("Unsubscribing from %s", topic)
        async with self._paho_lock:
            result, _ = await self.hass.async_add_executor_job(
                self._mqttc.unsubscribe, topic
            )
            _raise_on_error(result)

    async def _async_perform_subscription(self, topic: str, qos: int) -> None:
        """Perform a paho-mqtt subscription.

        :raises HomeAssistantError: If paho reports a non-zero result code.
        """
        _LOGGER.debug("Subscribing to %s", topic)

        async with self._paho_lock:
            result, _ = await self.hass.async_add_executor_job(
                self._mqttc.subscribe, topic, qos
            )
            _raise_on_error(result)

    def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
        """On connect callback.

        If the broker accepted the connection: mark the client as connected,
        send the MQTT_CONNECTED dispatcher signal, and re-subscribe to every
        topic that has subscriptions.
        """
        # pylint: disable=import-outside-toplevel
        import paho.mqtt.client as mqtt

        if result_code != mqtt.CONNACK_ACCEPTED:
            _LOGGER.error(
                "Unable to connect to the MQTT broker: %s",
                mqtt.connack_string(result_code),
            )
            return

        self.connected = True
        dispatcher_send(self.hass, MQTT_CONNECTED)
        _LOGGER.info(
            "Connected to MQTT server %s:%s (%s)", self.host, self.port, result_code,
        )

        # Group subscriptions to only re-subscribe once for each topic.
        keyfunc = attrgetter("topic")
        for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc):
            # Re-subscribe with the highest requested qos
            max_qos = max(subscription.qos for subscription in subs)
            self.hass.add_job(self._async_perform_subscription, topic, max_qos)

    def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
        """Message received callback; defers handling to the event loop."""
        self.hass.add_job(self._mqtt_handle_message, msg)

    @callback
    def _mqtt_handle_message(self, msg) -> None:
        """Dispatch a received message to every matching subscription.

        The payload is decoded per subscription. A subscription whose payload
        cannot be decoded is skipped with a warning.
        """
        _LOGGER.debug(
            "Received message on %s%s: %s",
            msg.topic,
            " (retained)" if msg.retain else "",
            msg.payload,
        )
        timestamp = dt_util.utcnow()

        for subscription in self.subscriptions:
            if not _match_topic(subscription.topic, msg.topic):
                continue

            payload: SubscribePayloadType = msg.payload
            if subscription.encoding is not None:
                try:
                    payload = msg.payload.decode(subscription.encoding)
                except (AttributeError, UnicodeDecodeError):
                    _LOGGER.warning(
                        "Can't decode payload %s on %s with encoding %s (for %s)",
                        msg.payload,
                        msg.topic,
                        subscription.encoding,
                        subscription.callback,
                    )
                    continue

            self.hass.async_create_task(
                subscription.callback(
                    Message(
                        msg.topic,
                        payload,
                        msg.qos,
                        msg.retain,
                        subscription.topic,
                        timestamp,
                    )
                )
            )

    def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
        """Disconnected callback; marks the client disconnected and signals it."""
        self.connected = False
        dispatcher_send(self.hass, MQTT_DISCONNECTED)
        _LOGGER.info(
            "Disconnected from MQTT server %s:%s (%s)",
            self.host,
            self.port,
            result_code,
        )
|