diff --git a/README.md b/README.md index 4270216..4704472 100644 --- a/README.md +++ b/README.md @@ -28,12 +28,12 @@ Tested: SH10RT services: sungrowmodbus2mqtt: container_name: sungrowmodbus2mqtt - image: ghcr.io/jesseklm/sungrowmodbus2mqtt:v1.0.20 + image: ghcr.io/jesseklm/sungrowmodbus2mqtt:v1.0.21 restart: unless-stopped volumes: - ./config.sh10rt.yaml:/config/config.yaml:ro ``` -- `wget -O config.sh10rt.yaml https://raw.githubusercontent.com/jesseklm/sungrowmodbus2mqtt/v1.0.20/config.sh10rt.example.yaml` +- `wget -O config.sh10rt.yaml https://raw.githubusercontent.com/jesseklm/sungrowmodbus2mqtt/v1.0.21/config.sh10rt.example.yaml` - adjust your config yaml - `docker compose up -d` diff --git a/config.py b/config.py index 1fc52fe..437882d 100644 --- a/config.py +++ b/config.py @@ -14,18 +14,18 @@ def get_config_local(filename: Path) -> dict: def get_first_config() -> dict: - files = [ + files: list[Path] = [ Path('/config/config.yaml'), Path('config.yaml'), Path('config.sh10rt.example.yaml') ] for file in files: if file.exists(): - loaded_config = get_config_local(file) + loaded_config: dict = get_config_local(file) break else: raise FileNotFoundError - options_files = [ + options_files: list[Path] = [ Path('/data/options.json'), Path('/data/options.yaml'), ] @@ -35,7 +35,7 @@ def get_first_config() -> dict: with open(options_file) as file: options: dict = json.load(file) else: - options = get_config_local(options_file) + options: dict = get_config_local(options_file) for key, option in options.items(): if isinstance(option, str) and option: loaded_config[key] = option diff --git a/modbus_handler.py b/modbus_handler.py index c9ea054..b839fd8 100644 --- a/modbus_handler.py +++ b/modbus_handler.py @@ -1,15 +1,17 @@ import logging import time +from typing import Callable from pymodbus.client import ModbusTcpClient from pymodbus.constants import Endian from pymodbus.exceptions import ConnectionException from pymodbus.payload import BinaryPayloadDecoder +from pymodbus.pdu import ModbusResponse from SungrowModbusTcpClient.SungrowModbusTcpClient import SungrowModbusTcpClient class ModbusHandler: - WORD_COUNT = { + WORD_COUNT: dict[str, int] = { 'uint16': 1, 'int16': 1, 'uint32': 2, @@ -18,18 +20,17 @@ class ModbusHandler: 'int64': 4, } - def __init__(self, config: dict): + def __init__(self, config: dict) -> None: self.host: str = config['ip'] self.port: int = config.get('port', 502) self.slave_id: int = config.get('slave_id', 0x1) self.byte_order: Endian = Endian.BIG if config.get('byte_order', 'big') == 'big' else Endian.LITTLE self.word_order: Endian = Endian.BIG if config.get('word_order', 'little') == 'big' else Endian.LITTLE - if config.get('sungrow_encrypted', False): - self.modbus_client = SungrowModbusTcpClient(host=self.host, port=self.port, timeout=10, retries=1) - else: - self.modbus_client = ModbusTcpClient(host=self.host, port=self.port, timeout=10, retries=1) + modbus_class: type[ModbusTcpClient] = SungrowModbusTcpClient if config.get('sungrow_encrypted', + False) else ModbusTcpClient + self.modbus_client: ModbusTcpClient = modbus_class(host=self.host, port=self.port, timeout=10, retries=1) - def reconnect(self, first_connect=False): + def reconnect(self, first_connect=False) -> None: while True: try: if self.modbus_client.connect(): @@ -43,9 +44,9 @@ def read(self, table: str, address: int, count: int) -> list[int]: while True: try: if table == 'holding': - result = self.modbus_client.read_holding_registers(address, count, self.slave_id) + result: ModbusResponse = self.modbus_client.read_holding_registers(address, count, self.slave_id) elif table == 'input': - result = self.modbus_client.read_input_registers(address, count, self.slave_id) + result: ModbusResponse = self.modbus_client.read_input_registers(address, count, self.slave_id) else: raise Exception('Invalid table') except (ConnectionResetError, ConnectionException) as e: @@ -59,13 +60,14 @@ def read(self, table: str, address: int, count: int) -> list[int]: continue return result.registers - def close(self): + def close(self) -> None: self.modbus_client.close() logging.info('modbus closed.') def decode(self, registers: list[int], datatype: str) -> int: - decoder = BinaryPayloadDecoder.fromRegisters(registers, byteorder=self.byte_order, wordorder=self.word_order) - decode_methods = { + decoder: BinaryPayloadDecoder = BinaryPayloadDecoder.fromRegisters(registers, byteorder=self.byte_order, + wordorder=self.word_order) + decode_methods: dict[str, Callable[[], int]] = { 'uint16': decoder.decode_16bit_uint, 'int16': decoder.decode_16bit_int, 'uint32': decoder.decode_32bit_uint, diff --git a/mqtt_handler.py b/mqtt_handler.py index d4d330b..6b02f17 100644 --- a/mqtt_handler.py +++ b/mqtt_handler.py @@ -2,50 +2,54 @@ import queue import threading import time +from typing import Any import paho.mqtt.client as mqtt +from paho.mqtt.client import ConnectFlags from paho.mqtt.packettypes import PacketTypes +from paho.mqtt.properties import Properties from paho.mqtt.reasoncodes import ReasonCode class MqttHandler: - def __init__(self, config: dict): + def __init__(self, config: dict) -> None: self.topic_prefix: str = config.get('mqtt_topic', 'sungrowmodbus2mqtt/').rstrip('/') + '/' self.host: str = config['mqtt_server'] self.port: int = config.get('mqtt_port', 1883) - self.mqttc = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2) + self.mqttc: mqtt.Client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2) self.mqttc.on_connect = self.on_connect self.mqttc.username_pw_set(config['mqtt_username'], config['mqtt_password']) self.mqttc.will_set(self.topic_prefix + 'available', 'offline', retain=True) self.mqttc.connect_async(host=self.host, port=self.port) self.mqttc.loop_start() - self.publishing_queue = queue.Queue() + self.publishing_queue: queue.Queue[dict[str, Any]] = queue.Queue() - self.publishing_thread = threading.Thread(target=self.publishing_handler, daemon=True) + self.publishing_thread: threading.Thread = threading.Thread(target=self.publishing_handler, daemon=True) self.publishing_thread.start() - def on_connect(self, client, userdata, connect_flags, reason_code, properties): + def on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: ConnectFlags, reason_code: ReasonCode, + properties: Properties | None) -> None: self.mqttc.publish(self.topic_prefix + 'available', 'online', retain=True) if reason_code == ReasonCode(PacketTypes.CONNACK, 'Success'): logging.info('mqtt connected.') else: logging.error(f'mqtt connection to %s:%s failed, %s.', self.host, self.port, reason_code) - def publish(self, topic: str, payload: str | int | float, retain=False): + def publish(self, topic: str, payload: str | int | float, retain=False) -> None: self.publishing_queue.put({ 'topic': self.topic_prefix + topic, 'payload': payload, 'retain': retain, }) - def publishing_handler(self): + def publishing_handler(self) -> None: while True: - message: dict = self.publishing_queue.get() + message: dict[str, Any] = self.publishing_queue.get() while not self.mqttc.is_connected(): time.sleep(1) - result = self.mqttc.publish(**message) + result: mqtt.MQTTMessageInfo = self.mqttc.publish(**message) if result.rc != mqtt.MQTT_ERR_SUCCESS: logging.error(f'mqtt publish failed: %s %s.', message, result) self.publishing_queue.task_done() diff --git a/sungrowmodbus2mqtt.py b/sungrowmodbus2mqtt.py index 334adea..6f8f1ab 100644 --- a/sungrowmodbus2mqtt.py +++ b/sungrowmodbus2mqtt.py @@ -2,107 +2,109 @@ import signal import sys import time +from typing import Any from config import get_first_config from modbus_handler import ModbusHandler from mqtt_handler import MqttHandler -__version__ = '1.0.20' +__version__ = '1.0.21' class SungrowModbus2Mqtt: - def __init__(self): + def __init__(self) -> None: signal.signal(signal.SIGINT, self.exit_handler) signal.signal(signal.SIGTERM, self.exit_handler) - config = get_first_config() + config: dict = get_first_config() if 'logging' in config: - logging_level_name = config['logging'].upper() - logging_level = logging.getLevelNamesMapping().get(logging_level_name, logging.NOTSET) + logging_level_name: str = config['logging'].upper() + logging_level: int = logging.getLevelNamesMapping().get(logging_level_name, logging.NOTSET) if logging_level != logging.NOTSET: logging.getLogger().setLevel(logging_level) else: logging.warning(f'unknown logging level: %s.', logging_level) - self.mqtt_handler = MqttHandler(config) - self.modbus_handler = ModbusHandler(config) + self.mqtt_handler: MqttHandler = MqttHandler(config) + self.modbus_handler: ModbusHandler = ModbusHandler(config) self.modbus_handler.reconnect(first_connect=True) self.address_offset: int = config.get('address_offset', 0) self.old_value_map: bool = config.get('old_value_map', False) self.scan_batching: int = config.get('scan_batching', 100) self.update_rate: int = config.get('update_rate', 2) - self.registers = { + self.registers: dict[str, dict[int, dict[str, Any]]] = { 'holding': {}, 'input': {}, } self.init_registers(config) - def loop(self): + def loop(self) -> None: while True: - start_time = time.perf_counter() + start_time: float = time.perf_counter() self.read(start_time) self.publish() - time_taken = time.perf_counter() - start_time - time_to_sleep = self.update_rate - time_taken + time_taken: float = time.perf_counter() - start_time + time_to_sleep: float = self.update_rate - time_taken logging.debug('looped in %.2fms, sleeping %.2fs.', time_taken * 1000, time_to_sleep) if time_to_sleep > 0: time.sleep(time_to_sleep) - def exit_handler(self, signum, frame): + def exit_handler(self, signum, frame) -> None: self.modbus_handler.close() sys.exit(0) - def add_dummy_register(self, register_table: str, address: int): + def add_dummy_register(self, register_table: str, address: int) -> None: self.registers[register_table][address] = {'type': 'dummy'} - def create_register(self, register_table: str, config_register: dict) -> dict: - register = { + def create_register(self, register_table: str, config_register: dict) -> dict[str, Any]: + register: dict[str, Any] = { 'topic': config_register['pub_topic'], 'type': config_register.get('type', 'uint16').strip().lower(), } if 'value_map' in config_register: - value_map = config_register['value_map'] + value_map: dict = config_register['value_map'] if self.old_value_map: value_map = {v: k for k, v in value_map.items()} register['map'] = value_map for option in ['scale', 'mask', 'shift', 'retain']: if option in config_register: register[option] = config_register[option] - word_count = ModbusHandler.WORD_COUNT.get(register['type'], 1) + word_count: int = ModbusHandler.WORD_COUNT.get(register['type'], 1) for i in range(1, word_count): self.add_dummy_register(register_table, config_register['address'] + self.address_offset + i) return register - def init_register(self, register_table: str, register: dict): - new_register = self.create_register(register_table, register) + def init_register(self, register_table: str, register: dict) -> None: + new_register: dict[str, Any] = self.create_register(register_table, register) register_address: int = register['address'] + self.address_offset - existing_register = self.registers[register_table].setdefault(register_address, new_register) + existing_register: dict[str, Any] = self.registers[register_table].setdefault(register_address, new_register) if existing_register is not new_register: existing_register.setdefault('multi', []).append(new_register) - def init_registers(self, config: dict): + def init_registers(self, config: dict) -> None: for register_type in ['registers', 'input', 'holding']: for register in config.get(register_type, []): - register_table = register.get('table', 'holding') if register_type == 'registers' else register_type + register_table: str = register.get('table', + 'holding') if register_type == 'registers' else register_type self.init_register(register_table, register) self.registers = {table: dict(sorted(register.items())) for table, register in self.registers.items()} - def read(self, start_time: float): + def read(self, start_time: float) -> None: for table, table_registers in self.registers.items(): for address, register in list(table_registers.items()): if start_time - register.get('last_fetch', 0) < self.update_rate - 0.001: continue - count = register.get('read_count', self.scan_batching) + count: int = register.get('read_count', self.scan_batching) if 'read_count' not in register: count = next((i + 1 for i in range(count - 1, -1, -1) if address + i in table_registers)) register['read_count'] = count logging.debug(f'read: table:%s address:%s count:%s.', table, address, count) - result = self.modbus_handler.read(table, address, count) + result: list[int] = self.modbus_handler.read(table, address, count) for result_address, result_register in enumerate(result, start=address): if result_address not in table_registers: continue - table_register = table_registers[result_address] + table_register: dict[str, Any] = table_registers[result_address] table_register['last_fetch'] = start_time if table_register.get('value') == result_register: table_register['new'] = False @@ -111,7 +113,7 @@ def read(self, start_time: float): table_register['new'] = True @staticmethod - def prepare_value(register: dict, value: int) -> str | int | float: + def prepare_value(register: dict[str, Any], value: int) -> str | int | float: if value_map := register.get('map'): return value_map.get(value, f'{value:#x} not mapped!') if mask := register.get('mask'): @@ -119,19 +121,19 @@ def prepare_value(register: dict, value: int) -> str | int | float: if shift := register.get('shift'): value >>= shift if scale := register.get('scale'): - value = round(value * scale, 10) + value: int | float = round(value * scale, 10) return value - def publish(self): + def publish(self) -> None: for table, table_registers in self.registers.items(): for address, register in table_registers.items(): if (register_type := register['type']) == 'dummy': continue - word_count = ModbusHandler.WORD_COUNT.get(register_type, 1) + word_count: int = ModbusHandler.WORD_COUNT.get(register_type, 1) if not any(table_registers[address + i].get('new', False) for i in range(word_count)): continue values: list[int] = [table_registers[address + i]['value'] for i in range(word_count)] - value = self.modbus_handler.decode(values, register_type) + value: int = self.modbus_handler.decode(values, register_type) for subregister in register.get('multi', []): self.mqtt_handler.publish(subregister['topic'], self.prepare_value(subregister, value), subregister.get('retain', False)) @@ -143,5 +145,5 @@ def publish(self): logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.getLogger('pymodbus').setLevel(logging.INFO) logging.info(f'starting SungrowModbus2Mqtt v%s.', __version__) - app = SungrowModbus2Mqtt() + app: SungrowModbus2Mqtt = SungrowModbus2Mqtt() app.loop()