diff --git a/setup.py b/setup.py index 1eec06d..753b719 100755 --- a/setup.py +++ b/setup.py @@ -10,14 +10,14 @@ python_requires=">=3.7", install_requires=[ 'pyserial-asyncio', - 'websockets<9', # hbmqtt currently (2021-06-07) not compatible with >=9 + 'websockets', 'sanic', 'bitstruct', 'attrs>=17.3.0', 'structattr', 'sortedcontainers', 'python-dateutil ', - 'hbmqtt', + 'paho-mqtt', ], setup_requires=[ 'pytest-runner' diff --git a/src/velbus/__main__.py b/src/velbus/__main__.py index 3c8dcf0..13563a8 100644 --- a/src/velbus/__main__.py +++ b/src/velbus/__main__.py @@ -134,9 +134,17 @@ def handle_sighup(): match = re.fullmatch("(?P[a-z]+)://(?P[^/]+)(?P/.*)", uri) if not match: raise ValueError(f"Invalid MQTT URI `{uri}`") - mqtt_uri = f"{match.group('proto')}://{match.group('host')}" + + host = match.group('host') + if ':' in host: + host, port = host.split(':') + port = int(port) + else: + port = 1883 + + mqtt_uri = f"{match.group('proto')}://{host}:{port}" mqtt_topic = match.group('topic')[1:] - sync = MqttStateSync(mqtt_uri=mqtt_uri, mqtt_topic_prefix=mqtt_topic) + sync = MqttStateSync(mqtt_host=host, mqtt_port=port, mqtt_topic_prefix=mqtt_topic) asyncio.get_event_loop().run_until_complete(sync.connect()) # may raise logger.info(f"Connected to MQTT {mqtt_uri}, topic prefix {mqtt_topic}") HttpApi.mqtt_sync_clients.add(sync) diff --git a/src/velbus/mqtt.py b/src/velbus/mqtt.py index b5e6e42..53ff3a9 100644 --- a/src/velbus/mqtt.py +++ b/src/velbus/mqtt.py @@ -1,46 +1,62 @@ +import typing import asyncio +import queue -import hbmqtt.client -import hbmqtt.mqtt.constants -import typing +import paho.mqtt.client as mqtt from .JsonPatchDict import JsonPatchOperation class MqttStateSync: def __init__(self, - mqtt_uri: str = "mqtt://localhost", + mqtt_host: str = "localhost", + mqtt_port: int = 1883, mqtt_topic_prefix: str = "", ): - self.mqtt_uri = mqtt_uri + self.mqtt_host = mqtt_host + self.mqtt_port = mqtt_port self.mqtt_topic_prefix = mqtt_topic_prefix - self.connection = hbmqtt.client.MQTTClient() + self.connection = mqtt.Client() + self._loop = asyncio.get_event_loop() + self._connected = asyncio.Future() async def connect(self): - await self.connection.connect(uri=self.mqtt_uri) + self.connection.on_connect = self._mqtt_thread_on_connect + self.connection.connect(host=self.mqtt_host, port=self.mqtt_port) + self.connection.loop_start() # in separate thread + await self._connected + + def _mqtt_thread_on_connect(self, client, userdata, flags, rc): + self._loop.call_soon_threadsafe(self._on_connect, client, userdata, flags, rc) + + def _on_connect(self, client, userdata, flags, rc): + if rc == 0: + self._connected.set_result(True) + else: + self._connected.set_exception(RuntimeError("connection failed")) def __hash__(self) -> int: - return hash((self.mqtt_uri, self.mqtt_topic_prefix)) + return hash((self.mqtt_host, self.mqtt_port, self.mqtt_topic_prefix)) def __eq__(self, other) -> bool: if not isinstance(other, MqttStateSync): return False - return (self.mqtt_uri, self.mqtt_topic_prefix) == (other.mqtt_uri, other.mqtt_topic_prefix) + return (self.mqtt_host, self.mqtt_port, self.mqtt_topic_prefix) \ + == (other.mqtt_host, other.mqtt_port, other.mqtt_topic_prefix) def __repr__(self) -> str: - return f"{self.__class__.__name__}(mqtt_uri={repr(self.mqtt_uri)}, " \ + return f"{self.__class__.__name__}(mqtt_host={repr(self.mqtt_host)}, " \ + f"mqtt_port={repr(self.mqtt_port)}, " \ f"mqtt_topic_prefix={repr(self.mqtt_topic_prefix)})" async def publish(self, op: JsonPatchOperation) -> None: if op.op == JsonPatchOperation.Operation.remove: - return await self.publish_single(path=op.path, value=b"") + return self.publish_single(path=op.path, value=b"") # else: # replace or add - coroutines = [] for simple_op in op.decompose(): - coroutines.append(self.publish_single(path=simple_op.path, value=str(simple_op.value).encode('utf-8'))) - await asyncio.gather(*coroutines) + self.publish_single(path=simple_op.path, value=str(simple_op.value).encode('utf-8')) - async def publish_single(self, path: typing.List[str], value: bytes) -> None: + def publish_single(self, path: typing.List[str], value: bytes) -> None: topic = self.mqtt_topic_prefix + '/' + '/'.join(path) - await self.connection.publish(topic=topic, message=value, qos=hbmqtt.mqtt.constants.QOS_2, retain=True) + self.connection.publish(topic=topic, payload=value, qos=2, retain=True) diff --git a/tests/mqtt_test.py b/tests/mqtt_test.py index 6187f6f..7c294a7 100644 --- a/tests/mqtt_test.py +++ b/tests/mqtt_test.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from unittest import mock @@ -12,14 +10,10 @@ async def test_publish(): sync = MqttStateSync() # don't call connect - pub_mock = mock.Mock() - async def async_pub_mock(*args, **kwargs): - await asyncio.sleep(0) - return pub_mock(*args, **kwargs) - sync.publish_single = async_pub_mock + sync.publish_single = mock.Mock() await sync.publish(JsonPatchOperation(JsonPatchOperation.Operation.add, ['test', '123'], 'foo')) - pub_mock.assert_called_once_with(path=['test', '123'], value=b'foo') + sync.publish_single.assert_called_once_with(path=['test', '123'], value=b'foo') @pytest.mark.asyncio @@ -27,17 +21,13 @@ async def test_publish_decompose(): sync = MqttStateSync() # don't call connect - pub_mock = mock.Mock() - async def async_pub_mock(*args, **kwargs): - await asyncio.sleep(0) - return pub_mock(*args, **kwargs) - sync.publish_single = async_pub_mock + sync.publish_single = mock.Mock() await sync.publish(JsonPatchOperation(JsonPatchOperation.Operation.add, ['test', '456'], {'hello': 'world', 'foo': True})) - pub_mock.assert_has_calls([ + sync.publish_single.assert_has_calls([ mock.call(path=['test', '456', 'hello'], value=b'world'), mock.call(path=['test', '456', 'foo'], value=b'True'), ], any_order=True)