Skip to content

Commit

Permalink
v1.0.21: perf_counter and typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jesseklm committed Sep 1, 2024
1 parent 03fc268 commit a785396
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 60 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ Tested: SH10RT
services:
sungrowmodbus2mqtt:
container_name: sungrowmodbus2mqtt
image: ghcr.io/jesseklm/sungrowmodbus2mqtt:v1.0.20
image: ghcr.io/jesseklm/sungrowmodbus2mqtt:v1.0.21
restart: unless-stopped
volumes:
- ./config.sh10rt.yaml:/config/config.yaml:ro
```
- `wget -O config.sh10rt.yaml https://raw.githubusercontent.com/jesseklm/sungrowmodbus2mqtt/v1.0.20/config.sh10rt.example.yaml`
- `wget -O config.sh10rt.yaml https://raw.githubusercontent.com/jesseklm/sungrowmodbus2mqtt/v1.0.21/config.sh10rt.example.yaml`
- adjust your config yaml
- `docker compose up -d`
8 changes: 4 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ def get_config_local(filename: Path) -> dict:


def get_first_config() -> dict:
files = [
files: list[Path] = [
Path('/config/config.yaml'),
Path('config.yaml'),
Path('config.sh10rt.example.yaml')
]
for file in files:
if file.exists():
loaded_config = get_config_local(file)
loaded_config: dict = get_config_local(file)
break
else:
raise FileNotFoundError
options_files = [
options_files: list[Path] = [
Path('/data/options.json'),
Path('/data/options.yaml'),
]
Expand All @@ -35,7 +35,7 @@ def get_first_config() -> dict:
with open(options_file) as file:
options: dict = json.load(file)
else:
options = get_config_local(options_file)
options: dict = get_config_local(options_file)
for key, option in options.items():
if isinstance(option, str) and option:
loaded_config[key] = option
Expand Down
26 changes: 14 additions & 12 deletions modbus_handler.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
import time
from typing import Callable

from pymodbus.client import ModbusTcpClient
from pymodbus.constants import Endian
from pymodbus.exceptions import ConnectionException
from pymodbus.payload import BinaryPayloadDecoder
from pymodbus.pdu import ModbusResponse
from SungrowModbusTcpClient.SungrowModbusTcpClient import SungrowModbusTcpClient


class ModbusHandler:
WORD_COUNT = {
WORD_COUNT: dict[str, int] = {
'uint16': 1,
'int16': 1,
'uint32': 2,
Expand All @@ -18,18 +20,17 @@ class ModbusHandler:
'int64': 4,
}

def __init__(self, config: dict):
def __init__(self, config: dict) -> None:
self.host: str = config['ip']
self.port: int = config.get('port', 502)
self.slave_id: int = config.get('slave_id', 0x1)
self.byte_order: Endian = Endian.BIG if config.get('byte_order', 'big') == 'big' else Endian.LITTLE
self.word_order: Endian = Endian.BIG if config.get('word_order', 'little') == 'big' else Endian.LITTLE
if config.get('sungrow_encrypted', False):
self.modbus_client = SungrowModbusTcpClient(host=self.host, port=self.port, timeout=10, retries=1)
else:
self.modbus_client = ModbusTcpClient(host=self.host, port=self.port, timeout=10, retries=1)
modbus_class: type[ModbusTcpClient] = SungrowModbusTcpClient if config.get('sungrow_encrypted',
False) else ModbusTcpClient
self.modbus_client: ModbusTcpClient = modbus_class(host=self.host, port=self.port, timeout=10, retries=1)

def reconnect(self, first_connect=False):
def reconnect(self, first_connect=False) -> None:
while True:
try:
if self.modbus_client.connect():
Expand All @@ -43,9 +44,9 @@ def read(self, table: str, address: int, count: int) -> list[int]:
while True:
try:
if table == 'holding':
result = self.modbus_client.read_holding_registers(address, count, self.slave_id)
result: ModbusResponse = self.modbus_client.read_holding_registers(address, count, self.slave_id)
elif table == 'input':
result = self.modbus_client.read_input_registers(address, count, self.slave_id)
result: ModbusResponse = self.modbus_client.read_input_registers(address, count, self.slave_id)
else:
raise Exception('Invalid table')
except (ConnectionResetError, ConnectionException) as e:
Expand All @@ -59,13 +60,14 @@ def read(self, table: str, address: int, count: int) -> list[int]:
continue
return result.registers

def close(self):
def close(self) -> None:
self.modbus_client.close()
logging.info('modbus closed.')

def decode(self, registers: list[int], datatype: str) -> int:
decoder = BinaryPayloadDecoder.fromRegisters(registers, byteorder=self.byte_order, wordorder=self.word_order)
decode_methods = {
decoder: BinaryPayloadDecoder = BinaryPayloadDecoder.fromRegisters(registers, byteorder=self.byte_order,
wordorder=self.word_order)
decode_methods: dict[str, Callable[[], int]] = {
'uint16': decoder.decode_16bit_uint,
'int16': decoder.decode_16bit_int,
'uint32': decoder.decode_32bit_uint,
Expand Down
22 changes: 13 additions & 9 deletions mqtt_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,54 @@
import queue
import threading
import time
from typing import Any

import paho.mqtt.client as mqtt
from paho.mqtt.client import ConnectFlags
from paho.mqtt.packettypes import PacketTypes
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCode


class MqttHandler:
def __init__(self, config: dict):
def __init__(self, config: dict) -> None:
self.topic_prefix: str = config.get('mqtt_topic', 'sungrowmodbus2mqtt/').rstrip('/') + '/'
self.host: str = config['mqtt_server']
self.port: int = config.get('mqtt_port', 1883)

self.mqttc = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
self.mqttc: mqtt.Client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
self.mqttc.on_connect = self.on_connect
self.mqttc.username_pw_set(config['mqtt_username'], config['mqtt_password'])
self.mqttc.will_set(self.topic_prefix + 'available', 'offline', retain=True)
self.mqttc.connect_async(host=self.host, port=self.port)
self.mqttc.loop_start()

self.publishing_queue = queue.Queue()
self.publishing_queue: queue.Queue[dict[str, Any]] = queue.Queue()

self.publishing_thread = threading.Thread(target=self.publishing_handler, daemon=True)
self.publishing_thread: threading.Thread = threading.Thread(target=self.publishing_handler, daemon=True)
self.publishing_thread.start()

def on_connect(self, client, userdata, connect_flags, reason_code, properties):
def on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: ConnectFlags, reason_code: ReasonCode,
properties: Properties | None) -> None:
self.mqttc.publish(self.topic_prefix + 'available', 'online', retain=True)
if reason_code == ReasonCode(PacketTypes.CONNACK, 'Success'):
logging.info('mqtt connected.')
else:
logging.error(f'mqtt connection to %s:%s failed, %s.', self.host, self.port, reason_code)

def publish(self, topic: str, payload: str | int | float, retain=False):
def publish(self, topic: str, payload: str | int | float, retain=False) -> None:
self.publishing_queue.put({
'topic': self.topic_prefix + topic,
'payload': payload,
'retain': retain,
})

def publishing_handler(self):
def publishing_handler(self) -> None:
while True:
message: dict = self.publishing_queue.get()
message: dict[str, Any] = self.publishing_queue.get()
while not self.mqttc.is_connected():
time.sleep(1)
result = self.mqttc.publish(**message)
result: mqtt.MQTTMessageInfo = self.mqttc.publish(**message)
if result.rc != mqtt.MQTT_ERR_SUCCESS:
logging.error(f'mqtt publish failed: %s %s.', message, result)
self.publishing_queue.task_done()
68 changes: 35 additions & 33 deletions sungrowmodbus2mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,107 +2,109 @@
import signal
import sys
import time
from typing import Any

from config import get_first_config
from modbus_handler import ModbusHandler
from mqtt_handler import MqttHandler

__version__ = '1.0.20'
__version__ = '1.0.21'


class SungrowModbus2Mqtt:
def __init__(self):
def __init__(self) -> None:
signal.signal(signal.SIGINT, self.exit_handler)
signal.signal(signal.SIGTERM, self.exit_handler)
config = get_first_config()
config: dict = get_first_config()
if 'logging' in config:
logging_level_name = config['logging'].upper()
logging_level = logging.getLevelNamesMapping().get(logging_level_name, logging.NOTSET)
logging_level_name: str = config['logging'].upper()
logging_level: int = logging.getLevelNamesMapping().get(logging_level_name, logging.NOTSET)
if logging_level != logging.NOTSET:
logging.getLogger().setLevel(logging_level)
else:
logging.warning(f'unknown logging level: %s.', logging_level)
self.mqtt_handler = MqttHandler(config)
self.modbus_handler = ModbusHandler(config)
self.mqtt_handler: MqttHandler = MqttHandler(config)
self.modbus_handler: ModbusHandler = ModbusHandler(config)
self.modbus_handler.reconnect(first_connect=True)
self.address_offset: int = config.get('address_offset', 0)
self.old_value_map: bool = config.get('old_value_map', False)
self.scan_batching: int = config.get('scan_batching', 100)
self.update_rate: int = config.get('update_rate', 2)
self.registers = {
self.registers: dict[str, dict[int, dict[str, Any]]] = {
'holding': {},
'input': {},
}
self.init_registers(config)

def loop(self):
def loop(self) -> None:
while True:
start_time = time.perf_counter()
start_time: float = time.perf_counter()
self.read(start_time)
self.publish()
time_taken = time.perf_counter() - start_time
time_to_sleep = self.update_rate - time_taken
time_taken: float = time.perf_counter() - start_time
time_to_sleep: float = self.update_rate - time_taken
logging.debug('looped in %.2fms, sleeping %.2fs.', time_taken * 1000, time_to_sleep)
if time_to_sleep > 0:
time.sleep(time_to_sleep)

def exit_handler(self, signum, frame):
def exit_handler(self, signum, frame) -> None:
self.modbus_handler.close()
sys.exit(0)

def add_dummy_register(self, register_table: str, address: int):
def add_dummy_register(self, register_table: str, address: int) -> None:
self.registers[register_table][address] = {'type': 'dummy'}

def create_register(self, register_table: str, config_register: dict) -> dict:
register = {
def create_register(self, register_table: str, config_register: dict) -> dict[str, Any]:
register: dict[str, Any] = {
'topic': config_register['pub_topic'],
'type': config_register.get('type', 'uint16').strip().lower(),
}
if 'value_map' in config_register:
value_map = config_register['value_map']
value_map: dict = config_register['value_map']
if self.old_value_map:
value_map = {v: k for k, v in value_map.items()}
register['map'] = value_map
for option in ['scale', 'mask', 'shift', 'retain']:
if option in config_register:
register[option] = config_register[option]
word_count = ModbusHandler.WORD_COUNT.get(register['type'], 1)
word_count: int = ModbusHandler.WORD_COUNT.get(register['type'], 1)
for i in range(1, word_count):
self.add_dummy_register(register_table, config_register['address'] + self.address_offset + i)
return register

def init_register(self, register_table: str, register: dict):
new_register = self.create_register(register_table, register)
def init_register(self, register_table: str, register: dict) -> None:
new_register: dict[str, Any] = self.create_register(register_table, register)
register_address: int = register['address'] + self.address_offset
existing_register = self.registers[register_table].setdefault(register_address, new_register)
existing_register: dict[str, Any] = self.registers[register_table].setdefault(register_address, new_register)
if existing_register is not new_register:
existing_register.setdefault('multi', []).append(new_register)

def init_registers(self, config: dict):
def init_registers(self, config: dict) -> None:
for register_type in ['registers', 'input', 'holding']:
for register in config.get(register_type, []):
register_table = register.get('table', 'holding') if register_type == 'registers' else register_type
register_table: str = register.get('table',
'holding') if register_type == 'registers' else register_type
self.init_register(register_table, register)
self.registers = {table: dict(sorted(register.items())) for table, register in self.registers.items()}

def read(self, start_time: float):
def read(self, start_time: float) -> None:
for table, table_registers in self.registers.items():
for address, register in list(table_registers.items()):
if start_time - register.get('last_fetch', 0) < self.update_rate - 0.001:
continue

count = register.get('read_count', self.scan_batching)
count: int = register.get('read_count', self.scan_batching)
if 'read_count' not in register:
count = next((i + 1 for i in range(count - 1, -1, -1) if address + i in table_registers))
register['read_count'] = count

logging.debug(f'read: table:%s address:%s count:%s.', table, address, count)
result = self.modbus_handler.read(table, address, count)
result: list[int] = self.modbus_handler.read(table, address, count)

for result_address, result_register in enumerate(result, start=address):
if result_address not in table_registers:
continue
table_register = table_registers[result_address]
table_register: dict[str, Any] = table_registers[result_address]
table_register['last_fetch'] = start_time
if table_register.get('value') == result_register:
table_register['new'] = False
Expand All @@ -111,27 +113,27 @@ def read(self, start_time: float):
table_register['new'] = True

@staticmethod
def prepare_value(register: dict, value: int) -> str | int | float:
def prepare_value(register: dict[str, Any], value: int) -> str | int | float:
if value_map := register.get('map'):
return value_map.get(value, f'{value:#x} not mapped!')
if mask := register.get('mask'):
value &= mask
if shift := register.get('shift'):
value >>= shift
if scale := register.get('scale'):
value = round(value * scale, 10)
value: int | float = round(value * scale, 10)
return value

def publish(self):
def publish(self) -> None:
for table, table_registers in self.registers.items():
for address, register in table_registers.items():
if (register_type := register['type']) == 'dummy':
continue
word_count = ModbusHandler.WORD_COUNT.get(register_type, 1)
word_count: int = ModbusHandler.WORD_COUNT.get(register_type, 1)
if not any(table_registers[address + i].get('new', False) for i in range(word_count)):
continue
values: list[int] = [table_registers[address + i]['value'] for i in range(word_count)]
value = self.modbus_handler.decode(values, register_type)
value: int = self.modbus_handler.decode(values, register_type)
for subregister in register.get('multi', []):
self.mqtt_handler.publish(subregister['topic'], self.prepare_value(subregister, value),
subregister.get('retain', False))
Expand All @@ -143,5 +145,5 @@ def publish(self):
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger('pymodbus').setLevel(logging.INFO)
logging.info(f'starting SungrowModbus2Mqtt v%s.', __version__)
app = SungrowModbus2Mqtt()
app: SungrowModbus2Mqtt = SungrowModbus2Mqtt()
app.loop()

0 comments on commit a785396

Please sign in to comment.