diff --git a/pyproject.toml b/pyproject.toml index 1868f93..9c17b46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ license-files = ["LICEN[CS]E*"] dependencies = [ "bleak", "pyserial-asyncio", "pycayennelpp" ] [project.optional-dependencies] -dev = ["pytest", "pytest-asyncio"] +dev = ["pytest", "pytest-asyncio", "black", "ruff"] [project.urls] Homepage = "https://github.com/fdlamotte/meshcore_py" diff --git a/src/meshcore/__init__.py b/src/meshcore/__init__.py index f64d872..c58f593 100644 --- a/src/meshcore/__init__.py +++ b/src/meshcore/__init__.py @@ -1,11 +1,23 @@ +"""A library for communicating with meshcore devices.""" import logging +from .ble_cx import BLEConnection +from .connection_manager import ConnectionManager +from .events import EventType +from .meshcore import MeshCore +from .serial_cx import SerialConnection +from .tcp_cx import TCPConnection + # Setup default logger logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -from meshcore.events import EventType -from meshcore.meshcore import MeshCore, logger -from meshcore.connection_manager import ConnectionManager -from meshcore.tcp_cx import TCPConnection -from meshcore.ble_cx import BLEConnection -from meshcore.serial_cx import SerialConnection +__all__ = [ + "BLEConnection", + "ConnectionManager", + "EventType", + "MeshCore", + "SerialConnection", + "TCPConnection", + "logger", +] diff --git a/src/meshcore/binary_commands.py b/src/meshcore/binary_commands.py deleted file mode 100644 index 95e700b..0000000 --- a/src/meshcore/binary_commands.py +++ /dev/null @@ -1,116 +0,0 @@ -import asyncio -import logging -from enum import Enum -import json -from .events import Event, EventType -from cayennelpp import LppFrame, LppData -from cayennelpp.lpp_type import LppType -from meshcore.lpp_json_encoder import lpp_json_encoder, my_lpp_types, lpp_format_val - -logger = logging.getLogger("meshcore") - -class BinaryReqType(Enum): - TELEMETRY = 0x03 - MMA = 0x04 - ACL = 0x05 - -def lpp_parse(buf): - """Parse a given byte string and return as a LppFrame object.""" - i = 0 - lpp_data_list = [] - while i < len(buf) and buf[i] != 0: - lppdata = LppData.from_bytes(buf[i:]) - lpp_data_list.append(lppdata) - i = i + len(lppdata) - - return json.loads(json.dumps(LppFrame(lpp_data_list), default=lpp_json_encoder)) - -def lpp_parse_mma (buf): - i = 0 - res = [] - while i < len(buf) and buf[i] != 0: - chan = buf[i] - i = i + 1 - type = buf[i] - lpp_type = LppType.get_lpp_type(type) - size = lpp_type.size - i = i + 1 - min = lpp_format_val(lpp_type, lpp_type.decode(buf[i:i+size])) - i = i + size - max = lpp_format_val(lpp_type, lpp_type.decode(buf[i:i+size])) - i = i + size - avg = lpp_format_val(lpp_type, lpp_type.decode(buf[i:i+size])) - i = i + size - res.append({"channel":chan, - "type":my_lpp_types[type][0], - "min":min, - "max":max, - "avg":avg, - }) - return res - -def parse_acl (buf): - i = 0 - res = [] - while i + 7 <= len(buf): - key = buf[i:i+6].hex() - perm = buf[i+6] - if (key != "000000000000"): - res.append({"key": key, "perm": perm}) - i = i + 7 - return res - -class BinaryCommandHandler : - """ Helper functions to handle binary requests through binary commands """ - def __init__ (self, c): - self.commands = c - - @property - def dispatcher(self): - return self.commands.dispatcher - - async def req_binary (self, contact, request, timeout=0) : - res = await self.commands.send_binary_req(contact, request) - logger.debug(res) - if res.type == EventType.ERROR: - logger.error(f"Error while requesting binary data") - return None - else: - exp_tag = res.payload["expected_ack"].hex() - timeout = res.payload["suggested_timeout"]/800 if timeout == 0 else timeout - res2 = await self.dispatcher.wait_for_event(EventType.BINARY_RESPONSE, attribute_filters={"tag": exp_tag}, timeout=timeout) - logger.debug(res2) - if res2 is None : - return None - else: - return res2.payload - - async def req_telemetry (self, contact, timeout=0) : - code = BinaryReqType.TELEMETRY.value - req = code.to_bytes(1, 'little', signed=False) - res = await self.req_binary(contact, req, timeout) - if (res is None) : - return None - else: - return lpp_parse(bytes.fromhex(res["data"])) - - async def req_mma (self, contact, start, end, timeout=0) : - code = BinaryReqType.MMA.value - req = code.to_bytes(1, 'little', signed=False)\ - + start.to_bytes(4, 'little', signed = False)\ - + end.to_bytes(4, 'little', signed=False)\ - + b"\0\0" - res = await self.req_binary(contact, req, timeout) - if (res is None) : - return None - else: - return lpp_parse_mma(bytes.fromhex(res["data"])[4:]) - - async def req_acl (self, contact, timeout=0) : - code = BinaryReqType.ACL.value - req = code.to_bytes(1, 'little', signed=False) + b"\0\0" - res = await self.req_binary(contact, req, timeout) - if (res is None) : - return None - else: - return parse_acl(bytes.fromhex(res['data'])) diff --git a/src/meshcore/ble_cx.py b/src/meshcore/ble_cx.py index e480546..517ef1b 100644 --- a/src/meshcore/ble_cx.py +++ b/src/meshcore/ble_cx.py @@ -1,27 +1,29 @@ -""" - mccli.py : CLI interface to MeschCore BLE companion app """ +mccli.py : CLI interface to MeschCore BLE companion app +""" + import asyncio import logging -# Get logger -logger = logging.getLogger("meshcore") - from bleak import BleakClient, BleakScanner from bleak.backends.characteristic import BleakGATTCharacteristic from bleak.backends.device import BLEDevice from bleak.backends.scanner import AdvertisementData from bleak.exc import BleakDeviceNotFoundError +# Get logger +logger = logging.getLogger("meshcore") + UART_SERVICE_UUID = "6E400001-B5A3-F393-E0A9-E50E24DCCA9E" UART_RX_CHAR_UUID = "6E400002-B5A3-F393-E0A9-E50E24DCCA9E" UART_TX_CHAR_UUID = "6E400003-B5A3-F393-E0A9-E50E24DCCA9E" + class BLEConnection: def __init__(self, address=None, client=None): """ Constructor: specify address or an existing BleakClient. - + Args: address (str, optional): The Bluetooth address of the device. client (BleakClient, optional): An existing BleakClient instance. @@ -48,7 +50,9 @@ class BLEConnection: logger.debug("Using pre-configured BleakClient.") # If a client is already provided, ensure its disconnect callback is set self.client._disconnected_callback = self.handle_disconnect + self.address = self.client.address else: + def match_meshcore_device(_: BLEDevice, adv: AdvertisementData): """Filter to match MeshCore devices.""" if adv.local_name and adv.local_name.startswith("MeshCore"): @@ -63,10 +67,14 @@ class BLEConnection: logger.warning("No MeshCore device found during scan.") return None logger.info(f"Found device: {device}") - self.client = BleakClient(device, disconnected_callback=self.handle_disconnect) + self.client = BleakClient( + device, disconnected_callback=self.handle_disconnect + ) self.address = self.client.address else: - self.client = BleakClient(self.address, disconnected_callback=self.handle_disconnect) + self.client = BleakClient( + self.address, disconnected_callback=self.handle_disconnect + ) try: await self.client.connect() @@ -87,24 +95,26 @@ class BLEConnection: return self.address def handle_disconnect(self, client: BleakClient): - """ Callback to handle disconnection """ - logger.debug(f"BLE device disconnected: {client.address} (is_connected: {client.is_connected})") + """Callback to handle disconnection""" + logger.debug( + f"BLE device disconnected: {client.address} (is_connected: {client.is_connected})" + ) # Reset the address we found to what user specified # this allows to reconnect to the same device self.address = self._user_provided_address - + if self._disconnect_callback: asyncio.create_task(self._disconnect_callback("ble_disconnect")) - + def set_disconnect_callback(self, callback): """Set callback to handle disconnections.""" self._disconnect_callback = callback - def set_reader(self, reader) : + def set_reader(self, reader): self.reader = reader def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray): - if not self.reader is None: + if self.reader is not None: asyncio.create_task(self.reader.handle_rx(data)) async def send(self, data): @@ -114,8 +124,8 @@ class BLEConnection: if not self.rx_char: logger.error("RX characteristic not found") return False - await self.client.write_gatt_char(self.rx_char, bytes(data), response=False) - + await self.client.write_gatt_char(self.rx_char, bytes(data), response=True) + async def disconnect(self): """Disconnect from the BLE device.""" if self.client and self.client.is_connected: diff --git a/src/meshcore/commands.py b/src/meshcore/commands.py deleted file mode 100644 index 2d5bbf1..0000000 --- a/src/meshcore/commands.py +++ /dev/null @@ -1,491 +0,0 @@ -import asyncio -import logging -import random -from typing import Any, Dict, List, Optional, Union -from .events import Event, EventType -from .binary_commands import BinaryCommandHandler - -# Define types for destination parameters -DestinationType = Union[bytes, str, Dict[str, Any]] - -logger = logging.getLogger("meshcore") - -def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes: - """ - Validates and converts a destination to a bytes object. - - Args: - dst: The destination, which can be: - - str: Hex string representation of a public key - - dict: Contact object with a "public_key" field - prefix_length: The length of the prefix to use (default: 6 bytes) - - Returns: - bytes: The destination public key as a bytes object - - Raises: - ValueError: If dst is invalid or doesn't contain required fields - """ - if isinstance(dst, bytes): - # Already bytes, use directly - return dst[:prefix_length] - elif isinstance(dst, str): - # Hex string, convert to bytes - try: - return bytes.fromhex(dst)[:prefix_length] - except ValueError: - raise ValueError(f"Invalid public key hex string: {dst}") - elif isinstance(dst, dict): - # Contact object, extract public_key - if "public_key" not in dst: - raise ValueError("Contact object must have a 'public_key' field") - try: - return bytes.fromhex(dst["public_key"])[:prefix_length] - except ValueError: - raise ValueError(f"Invalid public_key in contact: {dst['public_key']}") - else: - raise ValueError(f"Destination must be a public key string or contact object, got: {type(dst)}") - -class CommandHandler: - DEFAULT_TIMEOUT = 5.0 - - def __init__(self, default_timeout: Optional[float] = None): - self._sender_func = None - self._reader = None - self.dispatcher = None - self.binary = BinaryCommandHandler(self) - self.default_timeout = default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT - - def set_connection(self, connection: Any) -> None: - async def sender(data: bytes) -> None: - await connection.send(data) - self._sender_func = sender - - def set_reader(self, reader: Any) -> None: - self._reader = reader - - def set_dispatcher(self, dispatcher: Any) -> None: - self.dispatcher = dispatcher - - async def send(self, data: bytes, expected_events: Optional[Union[EventType, List[EventType]]] = None, - timeout: Optional[float] = None) -> Event: - """ - Send a command and wait for expected event responses. - - Args: - data: The data to send - expected_events: EventType or list of EventTypes to wait for - timeout: Timeout in seconds, or None to use default_timeout - - Returns: - Event: The full event object that was received in response to the command - """ - if not self.dispatcher: - raise RuntimeError("Dispatcher not set, cannot send commands") - - # Use the provided timeout or fall back to default_timeout - timeout = timeout if timeout is not None else self.default_timeout - - if self._sender_func: - logger.debug(f"Sending raw data: {data.hex() if isinstance(data, bytes) else data}") - await self._sender_func(data) - - if expected_events: - try: - # Convert single event to list if needed - if not isinstance(expected_events, list): - expected_events = [expected_events] - - logger.debug(f"Waiting for events {expected_events}, timeout={timeout}") - - # Create futures for all expected events - futures = [] - for event_type in expected_events: - future = asyncio.create_task( - self.dispatcher.wait_for_event(event_type, {}, timeout) - ) - futures.append(future) - - # Wait for the first event to complete or all to timeout - done, pending = await asyncio.wait( - futures, - timeout=timeout, - return_when=asyncio.FIRST_COMPLETED - ) - - # Cancel all pending futures - for future in pending: - future.cancel() - - # Check if any future completed successfully - for future in done: - event = await future - if event: - return event - - # Create an error event when no event is received - return Event(EventType.ERROR, {"reason": "no_event_received"}) - except asyncio.TimeoutError: - logger.debug(f"Command timed out {data}") - return Event(EventType.ERROR, {"reason": "timeout"}) - except Exception as e: - logger.debug(f"Command error: {e}") - return Event(EventType.ERROR, {"error": str(e)}) - # For commands that don't expect events, return a success event - return Event(EventType.OK, {}) - - - async def send_appstart(self) -> Event: - logger.debug("Sending appstart command") - b1 = bytearray(b'\x01\x03 mccli') - return await self.send(b1, [EventType.SELF_INFO]) - - async def send_device_query(self) -> Event: - logger.debug("Sending device query command") - return await self.send(b"\x16\x03", [EventType.DEVICE_INFO, EventType.ERROR]) - - async def send_advert(self, flood: bool = False) -> Event: - logger.debug(f"Sending advertisement command (flood={flood})") - if flood: - return await self.send(b"\x07\x01", [EventType.OK, EventType.ERROR]) - else: - return await self.send(b"\x07", [EventType.OK, EventType.ERROR]) - - async def set_name(self, name: str) -> Event: - logger.debug(f"Setting device name to: {name}") - return await self.send(b'\x08' + name.encode("utf-8"), [EventType.OK, EventType.ERROR]) - - async def set_coords(self, lat: float, lon: float) -> Event: - logger.debug(f"Setting coordinates to: lat={lat}, lon={lon}") - return await self.send(b'\x0e'\ - + int(lat*1e6).to_bytes(4, 'little', signed=True)\ - + int(lon*1e6).to_bytes(4, 'little', signed=True)\ - + int(0).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - - async def reboot(self) -> Event: - logger.debug("Sending reboot command") - return await self.send(b'\x13reboot') - - async def get_bat(self) -> Event: - logger.debug("Getting battery information") - return await self.send(b'\x14', [EventType.BATTERY, EventType.ERROR]) - - async def get_time(self) -> Event: - logger.debug("Getting device time") - return await self.send(b"\x05", [EventType.CURRENT_TIME, EventType.ERROR]) - - async def set_time(self, val: int) -> Event: - logger.debug(f"Setting device time to: {val}") - return await self.send(b"\x06" + int(val).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - - async def set_tx_power(self, val: int) -> Event: - logger.debug(f"Setting TX power to: {val}") - return await self.send(b"\x0c" + int(val).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - - async def set_radio(self, freq: float, bw: float, sf: int, cr: int) -> Event: - logger.debug(f"Setting radio params: freq={freq}, bw={bw}, sf={sf}, cr={cr}") - return await self.send(b"\x0b" \ - + int(float(freq)*1000).to_bytes(4, 'little')\ - + int(float(bw)*1000).to_bytes(4, 'little')\ - + int(sf).to_bytes(1, 'little')\ - + int(cr).to_bytes(1, 'little'), [EventType.OK, EventType.ERROR]) - - async def set_tuning(self, rx_dly: int, af: int) -> Event: - logger.debug(f"Setting tuning params: rx_dly={rx_dly}, af={af}") - return await self.send(b"\x15" \ - + int(rx_dly).to_bytes(4, 'little')\ - + int(af).to_bytes(4, 'little')\ - + int(0).to_bytes(1, 'little')\ - + int(0).to_bytes(1, 'little'), [EventType.OK, EventType.ERROR]) - - async def set_other_params(self, manual_add_contacts : bool, telemetry_mode_base : int, telemetry_mode_loc : int, telemetry_mode_env : int, advert_loc_policy : int) : - telemetry_mode = (telemetry_mode_base & 0b11) | ((telemetry_mode_loc & 0b11) << 2) | ((telemetry_mode_env & 0b11) << 4) - data = b"\x26" + manual_add_contacts.to_bytes(1) + telemetry_mode.to_bytes(1) + advert_loc_policy.to_bytes(1) - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def set_telemetry_mode_base(self, telemetry_mode_base : int) : - infos = (await self.send_appstart()).payload - return await self.set_other_params( - infos["manual_add_contacts"], - telemetry_mode_base, - infos["telemetry_mode_loc"], - infos["telemetry_mode_env"], - infos["adv_loc_policy"]) - - async def set_telemetry_mode_loc(self, telemetry_mode_loc : int) : - infos = (await self.send_appstart()).payload - return await self.set_other_params( - infos["manual_add_contacts"], - infos["telemetry_mode_base"], - telemetry_mode_loc, - infos["telemetry_mode_env"], - infos["adv_loc_policy"]) - - async def set_telemetry_mode_env(self, telemetry_mode_env : int) : - infos = (await self.send_appstart()).payload - return await self.set_other_params( - infos["manual_add_contacts"], - infos["telemetry_mode_base"], - infos["telemetry_mode_loc"], - telemetry_mode_env, - infos["adv_loc_policy"]) - - async def set_manual_add_contacts(self, manual_add_contacts:bool) : - infos = (await self.send_appstart()).payload - return await self.set_other_params( - manual_add_contacts, - infos["telemetry_mode_base"], - infos["telemetry_mode_loc"], - infos["telemetry_mode_env"], - infos["adv_loc_policy"]) - - async def set_advert_loc_policy(self, advert_loc_policy:int) : - infos = (await self.send_appstart()).payload - return await self.set_other_params( - infos["manual_add_contacts"], - infos["telemetry_mode_base"], - infos["telemetry_mode_loc"], - infos["telemetry_mode_env"], - advert_loc_policy) - - async def set_devicepin(self, pin: int) -> Event: - logger.debug(f"Setting device PIN to: {pin}") - return await self.send(b"\x25" \ - + int(pin).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - - async def get_contacts(self, lastmod=0) -> Event: - logger.debug("Getting contacts") - data=b"\x04" - if lastmod > 0: - data = data + lastmod.to_bytes(4, 'little') - return await self.send(data, [EventType.CONTACTS, EventType.ERROR]) - - async def reset_path(self, key: DestinationType) -> Event: - key_bytes = _validate_destination(key, prefix_length=32) - logger.debug(f"Resetting path for contact: {key_bytes.hex()}") - data = b"\x0D" + key_bytes - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def share_contact(self, key: DestinationType) -> Event: - key_bytes = _validate_destination(key, prefix_length=32) - logger.debug(f"Sharing contact: {key_bytes.hex()}") - data = b"\x10" + key_bytes - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def export_contact(self, key: Optional[DestinationType] = None) -> Event: - if key: - key_bytes = _validate_destination(key, prefix_length=32) - logger.debug(f"Exporting contact: {key_bytes.hex()}") - data = b"\x11" + key_bytes - else: - logger.debug("Exporting node") - data = b"\x11" - return await self.send(data, [EventType.CONTACT_URI, EventType.ERROR]) - - async def import_contact(self, card_data) -> Event: - data = b"\x12" + card_data - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def remove_contact(self, key: DestinationType) -> Event: - key_bytes = _validate_destination(key, prefix_length=32) - logger.debug(f"Removing contact: {key_bytes.hex()}") - data = b"\x0f" + key_bytes - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def update_contact (self, contact, path=None, flags=None) -> Event: - if path is None : - out_path_hex = contact["out_path"] - out_path_len = contact["out_path_len"] - else : - out_path_hex = path - out_path_len = int(len(path) / 2) - # reflect the change - contact["out_path"] = out_path_hex - contact["out_path_len"] = out_path_len - out_path_hex = out_path_hex + (128-len(out_path_hex)) * "0" - - if flags is None : - flags = contact["flags"] - else : - # reflect the change - contact["flags"] = flags - - adv_name_hex = contact["adv_name"].encode().hex() - adv_name_hex = adv_name_hex + (64-len(adv_name_hex)) * "0" - data = b"\x09" \ - + bytes.fromhex(contact["public_key"])\ - + contact["type"].to_bytes(1)\ - + flags.to_bytes(1)\ - + out_path_len.to_bytes(1, 'little', signed=True)\ - + bytes.fromhex(out_path_hex)\ - + bytes.fromhex(adv_name_hex)\ - + contact["last_advert"].to_bytes(4, 'little')\ - + int(contact["adv_lat"]*1e6).to_bytes(4, 'little', signed=True)\ - + int(contact["adv_lon"]*1e6).to_bytes(4, 'little', signed=True) - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def add_contact (self, contact) -> Event: - return await self.update_contact(contact) - - async def change_contact_path (self, contact, path) -> Event: - return await self.update_contact(contact, path) - - async def change_contact_flags (self, contact, flags) -> Event: - return await self.update_contact(contact, flags=flags) - - async def get_msg(self, timeout: Optional[float] = None) -> Event: - logger.debug("Requesting pending messages") - return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR, EventType.NO_MORE_MSGS], timeout) - - async def send_login(self, dst: DestinationType, pwd: str) -> Event: - dst_bytes = _validate_destination(dst, prefix_length=32) - logger.debug(f"Sending login request to: {dst_bytes.hex()}") - data = b"\x1a" + dst_bytes + pwd.encode("utf-8") - return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - - async def send_logout(self, dst: DestinationType) -> Event: - dst_bytes = _validate_destination(dst, prefix_length=32) - self.login_resp = asyncio.Future() - data = b"\x1d" + dst_bytes - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def send_statusreq(self, dst: DestinationType) -> Event: - dst_bytes = _validate_destination(dst, prefix_length=32) - logger.debug(f"Sending status request to: {dst_bytes.hex()}") - data = b"\x1b" + dst_bytes - return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - - async def send_cmd(self, dst: DestinationType, cmd: str, timestamp: Optional[int] = None) -> Event: - dst_bytes = _validate_destination(dst) - logger.debug(f"Sending command to {dst_bytes.hex()}: {cmd}") - - if timestamp is None: - import time - timestamp = int(time.time()) - - data = b"\x02\x01\x00" + timestamp.to_bytes(4, 'little') + dst_bytes + cmd.encode("utf-8") - return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - - async def send_msg(self, dst: DestinationType, msg: str, timestamp: Optional[int] = None) -> Event: - dst_bytes = _validate_destination(dst) - logger.debug(f"Sending message to {dst_bytes.hex()}: {msg}") - - if timestamp is None: - import time - timestamp = int(time.time()) - - data = b"\x02\x00\x00" + timestamp.to_bytes(4, 'little') + dst_bytes + msg.encode("utf-8") - return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - - async def send_chan_msg(self, chan, msg, timestamp=None) -> Event: - logger.debug(f"Sending channel message to channel {chan}: {msg}") - - # Default to current time if timestamp not provided - if timestamp is None: - import time - timestamp = int(time.time()).to_bytes(4, 'little') - - data = b"\x03\x00" + chan.to_bytes(1, 'little') + timestamp + msg.encode("utf-8") - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def send_telemetry_req(self, dst: DestinationType) -> Event : - dst_bytes = _validate_destination(dst, prefix_length=32) - logger.debug(f"Asking telemetry to {dst_bytes.hex()}") - data = b"\x27\x00\x00\x00" + dst_bytes - return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - - async def send_binary_req(self, dst: DestinationType, bin_data) -> Event : - dst_bytes = _validate_destination(dst, prefix_length=32) - logger.debug(f"Binary request to {dst_bytes.hex()}") - data = b"\x32" + dst_bytes + bin_data - return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - - async def send_path_discovery(self, dst: DestinationType) -> Event : - dst_bytes = _validate_destination(dst, prefix_length=32) - logger.debug(f"Path discovery request for {dst_bytes.hex()}") - data = b"\x34\x00" + dst_bytes - return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - - async def get_self_telemetry(self) -> Event : - logger.debug(f"Getting self telemetry") - data = b"\x27\x00\x00\x00" - return await self.send(data, [EventType.TELEMETRY_RESPONSE, EventType.ERROR]) - - async def get_custom_vars(self) -> Event: - logger.debug(f"Asking for custom vars") - data = b"\x28" - return await self.send(data, [EventType.CUSTOM_VARS, EventType.ERROR]) - - async def set_custom_var(self, key, value) -> Event: - logger.debug(f"Setting custom var {key} to {value}") - data = b"\x29" + key.encode("utf-8") + b":" + value.encode("utf-8") - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def get_channel(self, channel_idx: int) -> Event: - logger.debug(f"Getting channel info for channel {channel_idx}") - data = b"\x1f" + channel_idx.to_bytes(1, 'little') - return await self.send(data, [EventType.CHANNEL_INFO, EventType.ERROR]) - - async def set_channel(self, channel_idx: int, channel_name: str, channel_secret: bytes) -> Event: - logger.debug(f"Setting channel {channel_idx}: name={channel_name}") - - # Pad channel name to 32 bytes - name_bytes = channel_name.encode('utf-8')[:32] - name_bytes = name_bytes.ljust(32, b'\x00') - - # Ensure channel secret is exactly 16 bytes - if len(channel_secret) != 16: - raise ValueError("Channel secret must be exactly 16 bytes") - - data = b"\x20" + channel_idx.to_bytes(1, 'little') + name_bytes + channel_secret - return await self.send(data, [EventType.OK, EventType.ERROR]) - - async def send_trace(self, auth_code: int = 0, tag: Optional[int] = None, - flags: int = 0, path: Optional[Union[str, bytes, bytearray]] = None) -> Event: - """ - Send a trace packet to test routing through specific repeaters - - Args: - auth_code: 32-bit authentication code (default: 0) - tag: 32-bit integer to identify this trace (default: random) - flags: 8-bit flags field (default: 0) - path: Optional string with comma-separated hex values representing repeater pubkeys (e.g. "23,5f,3a") - or a bytes/bytearray object with the raw path data - - Returns: - Event object with sent status, tag, and estimated timeout in milliseconds - """ - # Generate random tag if not provided - if tag is None: - tag = random.randint(1, 0xFFFFFFFF) - if auth_code is None: - auth_code = random.randint(1, 0xFFFFFFFF) - - logger.debug(f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path}") - - # Prepare the command packet: CMD(1) + tag(4) + auth_code(4) + flags(1) + [path] - cmd_data = bytearray([36]) # CMD_SEND_TRACE_PATH - cmd_data.extend(tag.to_bytes(4, 'little')) - cmd_data.extend(auth_code.to_bytes(4, 'little')) - cmd_data.append(flags) - - # Process path if provided - if path: - if isinstance(path, str): - # Convert comma-separated hex values to bytes - try: - path_bytes = bytearray() - for hex_val in path.split(','): - hex_val = hex_val.strip() - path_bytes.append(int(hex_val, 16)) - cmd_data.extend(path_bytes) - except ValueError as e: - logger.error(f"Invalid path format: {e}") - return Event(EventType.ERROR, {"reason": "invalid_path_format"}) - elif isinstance(path, (bytes, bytearray)): - cmd_data.extend(path) - else: - logger.error(f"Unsupported path type: {type(path)}") - return Event(EventType.ERROR, {"reason": "unsupported_path_type"}) - - return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR]) diff --git a/src/meshcore/commands/__init__.py b/src/meshcore/commands/__init__.py new file mode 100644 index 0000000..0638847 --- /dev/null +++ b/src/meshcore/commands/__init__.py @@ -0,0 +1,18 @@ +from typing import Any, Optional + +from ..events import EventDispatcher +from ..reader import MessageReader +from .base import CommandHandlerBase +from .binary import BinaryCommandHandler +from .contact import ContactCommands +from .device import DeviceCommands +from .messaging import MessagingCommands + + +class CommandHandler( + DeviceCommands, ContactCommands, MessagingCommands, BinaryCommandHandler +): + pass + + +__all__ = ["CommandHandler"] diff --git a/src/meshcore/commands/base.py b/src/meshcore/commands/base.py new file mode 100644 index 0000000..ba6faad --- /dev/null +++ b/src/meshcore/commands/base.py @@ -0,0 +1,146 @@ +import asyncio +import logging +import random +from typing import Any, Callable, Coroutine, Dict, List, Optional, Union + +from ..events import Event, EventDispatcher, EventType +from ..reader import MessageReader + +# Define types for destination parameters +DestinationType = Union[bytes, str, Dict[str, Any]] + +logger = logging.getLogger("meshcore") + + +def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes: + """ + Validates and converts a destination to a bytes object. + + Args: + dst: The destination, which can be: + - str: Hex string representation of a public key + - dict: Contact object with a "public_key" field + prefix_length: The length of the prefix to use (default: 6 bytes) + + Returns: + bytes: The destination public key as a bytes object + + Raises: + ValueError: If dst is invalid or doesn't contain required fields + """ + if isinstance(dst, bytes): + # Already bytes, use directly + return dst[:prefix_length] + elif isinstance(dst, str): + # Hex string, convert to bytes + try: + return bytes.fromhex(dst)[:prefix_length] + except ValueError: + raise ValueError(f"Invalid public key hex string: {dst}") + elif isinstance(dst, dict): + # Contact object, extract public_key + if "public_key" not in dst: + raise ValueError("Contact object must have a 'public_key' field") + try: + return bytes.fromhex(dst["public_key"])[:prefix_length] + except ValueError: + raise ValueError(f"Invalid public_key in contact: {dst['public_key']}") + else: + raise ValueError( + f"Destination must be a public key string or contact object, got: {type(dst)}" + ) + + +class CommandHandlerBase: + DEFAULT_TIMEOUT = 5.0 + + def __init__(self, default_timeout: Optional[float] = None): + self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None + self._reader: Optional[MessageReader] = None + self.dispatcher: Optional[EventDispatcher] = None + self.default_timeout = ( + default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT + ) + + def set_connection(self, connection: Any) -> None: + async def sender(data: bytes) -> None: + await connection.send(data) + + self._sender_func = sender + + def set_reader(self, reader: MessageReader) -> None: + self._reader = reader + + def set_dispatcher(self, dispatcher: EventDispatcher) -> None: + self.dispatcher = dispatcher + + async def send( + self, + data: bytes, + expected_events: Optional[Union[EventType, List[EventType]]] = None, + timeout: Optional[float] = None, + ) -> Event: + """ + Send a command and wait for expected event responses. + + Args: + data: The data to send + expected_events: EventType or list of EventTypes to wait for + timeout: Timeout in seconds, or None to use default_timeout + + Returns: + Event: The full event object that was received in response to the command + """ + if not self.dispatcher: + raise RuntimeError("Dispatcher not set, cannot send commands") + + # Use the provided timeout or fall back to default_timeout + timeout = timeout if timeout is not None else self.default_timeout + + if self._sender_func: + logger.debug( + f"Sending raw data: {data.hex() if isinstance(data, bytes) else data}" + ) + await self._sender_func(data) + + if expected_events: + try: + # Convert single event to list if needed + if not isinstance(expected_events, list): + expected_events = [expected_events] + + logger.debug(f"Waiting for events {expected_events}, timeout={timeout}") + + # Create futures for all expected events + futures = [] + for event_type in expected_events: + future = asyncio.create_task( + self.dispatcher.wait_for_event(event_type, {}, timeout) + ) + futures.append(future) + + # Wait for the first event to complete or all to timeout + done, pending = await asyncio.wait( + futures, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + + # Cancel all pending futures + for future in pending: + future.cancel() + + # Check if any future completed successfully + for future in done: + event = await future + if event: + return event + + # Create an error event when no event is received + return Event(EventType.ERROR, {"reason": "no_event_received"}) + except asyncio.TimeoutError: + logger.debug(f"Command timed out {data}") + return Event(EventType.ERROR, {"reason": "timeout"}) + except Exception as e: + logger.debug(f"Command error: {e}") + return Event(EventType.ERROR, {"error": str(e)}) + # For commands that don't expect events, return a success event + return Event(EventType.OK, {}) diff --git a/src/meshcore/commands/binary.py b/src/meshcore/commands/binary.py new file mode 100644 index 0000000..eb1165c --- /dev/null +++ b/src/meshcore/commands/binary.py @@ -0,0 +1,126 @@ +import logging +from enum import Enum +import json +from .base import CommandHandlerBase +from ..events import EventType +from cayennelpp import LppFrame, LppData +from cayennelpp.lpp_type import LppType +from ..lpp_json_encoder import lpp_json_encoder, my_lpp_types, lpp_format_val + +logger = logging.getLogger("meshcore") + + +class BinaryReqType(Enum): + TELEMETRY = 0x03 + MMA = 0x04 + ACL = 0x05 + + +def lpp_parse(buf): + """Parse a given byte string and return as a LppFrame object.""" + i = 0 + lpp_data_list = [] + while i < len(buf) and buf[i] != 0: + lppdata = LppData.from_bytes(buf[i:]) + lpp_data_list.append(lppdata) + i = i + len(lppdata) + + return json.loads(json.dumps(LppFrame(lpp_data_list), default=lpp_json_encoder)) + + +def lpp_parse_mma(buf): + i = 0 + res = [] + while i < len(buf) and buf[i] != 0: + chan = buf[i] + i = i + 1 + type = buf[i] + lpp_type = LppType.get_lpp_type(type) + size = lpp_type.size + i = i + 1 + min = lpp_format_val(lpp_type, lpp_type.decode(buf[i : i + size])) + i = i + size + max = lpp_format_val(lpp_type, lpp_type.decode(buf[i : i + size])) + i = i + size + avg = lpp_format_val(lpp_type, lpp_type.decode(buf[i : i + size])) + i = i + size + res.append( + { + "channel": chan, + "type": my_lpp_types[type][0], + "min": min, + "max": max, + "avg": avg, + } + ) + return res + + +def parse_acl(buf): + i = 0 + res = [] + while i + 7 <= len(buf): + key = buf[i : i + 6].hex() + perm = buf[i + 6] + if key != "000000000000": + res.append({"key": key, "perm": perm}) + i = i + 7 + return res + + +class BinaryCommandHandler(CommandHandlerBase): + """Helper functions to handle binary requests through binary commands""" + + async def req_binary(self, contact, request, timeout=0): + res = await self.send_binary_req(contact, request) + logger.debug(res) + if res.type == EventType.ERROR: + logger.error("Error while requesting binary data") + return None + else: + exp_tag = res.payload["expected_ack"].hex() + timeout = ( + res.payload["suggested_timeout"] / 800 if timeout == 0 else timeout + ) + res2 = await self.dispatcher.wait_for_event( + EventType.BINARY_RESPONSE, + attribute_filters={"tag": exp_tag}, + timeout=timeout, + ) + logger.debug(res2) + if res2 is None: + return None + else: + return res2.payload + + async def req_telemetry(self, contact, timeout=0): + code = BinaryReqType.TELEMETRY.value + req = code.to_bytes(1, "little", signed=False) + res = await self.req_binary(contact, req, timeout) + if res is None: + return None + else: + return lpp_parse(bytes.fromhex(res["data"])) + + async def req_mma(self, contact, start, end, timeout=0): + code = BinaryReqType.MMA.value + req = ( + code.to_bytes(1, "little", signed=False) + + start.to_bytes(4, "little", signed=False) + + end.to_bytes(4, "little", signed=False) + + b"\0\0" + ) + res = await self.req_binary(contact, req, timeout) + if res is None: + return None + else: + return lpp_parse_mma(bytes.fromhex(res["data"])[4:]) + + async def req_acl(self, contact, timeout=0): + code = BinaryReqType.ACL.value + req = code.to_bytes(1, "little", signed=False) + b"\0\0" + res = await self.req_binary(contact, req, timeout) + if res is None: + return None + else: + return parse_acl(bytes.fromhex(res["data"])) diff --git a/src/meshcore/commands/contact.py b/src/meshcore/commands/contact.py new file mode 100644 index 0000000..ca26eca --- /dev/null +++ b/src/meshcore/commands/contact.py @@ -0,0 +1,91 @@ +import logging +from typing import Optional + +from ..events import Event, EventType +from .base import CommandHandlerBase, DestinationType, _validate_destination + +logger = logging.getLogger("meshcore") + + +class ContactCommands(CommandHandlerBase): + async def get_contacts(self, lastmod=0) -> Event: + logger.debug("Getting contacts") + data = b"\x04" + if lastmod > 0: + data = data + lastmod.to_bytes(4, "little") + return await self.send(data, [EventType.CONTACTS, EventType.ERROR]) + + async def reset_path(self, key: DestinationType) -> Event: + key_bytes = _validate_destination(key, prefix_length=32) + logger.debug(f"Resetting path for contact: {key_bytes.hex()}") + data = b"\x0d" + key_bytes + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def share_contact(self, key: DestinationType) -> Event: + key_bytes = _validate_destination(key, prefix_length=32) + logger.debug(f"Sharing contact: {key_bytes.hex()}") + data = b"\x10" + key_bytes + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def export_contact(self, key: Optional[DestinationType] = None) -> Event: + if key: + key_bytes = _validate_destination(key, prefix_length=32) + logger.debug(f"Exporting contact: {key_bytes.hex()}") + data = b"\x11" + key_bytes + else: + logger.debug("Exporting node") + data = b"\x11" + return await self.send(data, [EventType.CONTACT_URI, EventType.ERROR]) + + async def import_contact(self, card_data) -> Event: + data = b"\x12" + card_data + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def remove_contact(self, key: DestinationType) -> Event: + key_bytes = _validate_destination(key, prefix_length=32) + logger.debug(f"Removing contact: {key_bytes.hex()}") + data = b"\x0f" + key_bytes + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def update_contact(self, contact, path=None, flags=None) -> Event: + if path is None: + out_path_hex = contact["out_path"] + out_path_len = contact["out_path_len"] + else: + out_path_hex = path + out_path_len = int(len(path) / 2) + # reflect the change + contact["out_path"] = out_path_hex + contact["out_path_len"] = out_path_len + out_path_hex = out_path_hex + (128 - len(out_path_hex)) * "0" + + if flags is None: + flags = contact["flags"] + else: + # reflect the change + contact["flags"] = flags + + adv_name_hex = contact["adv_name"].encode().hex() + adv_name_hex = adv_name_hex + (64 - len(adv_name_hex)) * "0" + data = ( + b"\x09" + + bytes.fromhex(contact["public_key"]) + + contact["type"].to_bytes(1) + + flags.to_bytes(1) + + out_path_len.to_bytes(1, "little", signed=True) + + bytes.fromhex(out_path_hex) + + bytes.fromhex(adv_name_hex) + + contact["last_advert"].to_bytes(4, "little") + + int(contact["adv_lat"] * 1e6).to_bytes(4, "little", signed=True) + + int(contact["adv_lon"] * 1e6).to_bytes(4, "little", signed=True) + ) + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def add_contact(self, contact) -> Event: + return await self.update_contact(contact) + + async def change_contact_path(self, contact, path) -> Event: + return await self.update_contact(contact, path) + + async def change_contact_flags(self, contact, flags) -> Event: + return await self.update_contact(contact, flags=flags) diff --git a/src/meshcore/commands/device.py b/src/meshcore/commands/device.py new file mode 100644 index 0000000..73ba033 --- /dev/null +++ b/src/meshcore/commands/device.py @@ -0,0 +1,200 @@ +import logging +from typing import Optional + +from ..events import Event, EventType +from .base import CommandHandlerBase, DestinationType, _validate_destination + +logger = logging.getLogger("meshcore") + + +class DeviceCommands(CommandHandlerBase): + async def send_appstart(self) -> Event: + logger.debug("Sending appstart command") + b1 = bytearray(b"\x01\x03 mccli") + return await self.send(b1, [EventType.SELF_INFO]) + + async def send_device_query(self) -> Event: + logger.debug("Sending device query command") + return await self.send(b"\x16\x03", [EventType.DEVICE_INFO, EventType.ERROR]) + + async def send_advert(self, flood: bool = False) -> Event: + logger.debug(f"Sending advertisement command (flood={flood})") + if flood: + return await self.send(b"\x07\x01", [EventType.OK, EventType.ERROR]) + else: + return await self.send(b"\x07", [EventType.OK, EventType.ERROR]) + + async def set_name(self, name: str) -> Event: + logger.debug(f"Setting device name to: {name}") + return await self.send( + b"\x08" + name.encode("utf-8"), [EventType.OK, EventType.ERROR] + ) + + async def set_coords(self, lat: float, lon: float) -> Event: + logger.debug(f"Setting coordinates to: lat={lat}, lon={lon}") + return await self.send( + b"\x0e" + + int(lat * 1e6).to_bytes(4, "little", signed=True) + + int(lon * 1e6).to_bytes(4, "little", signed=True) + + int(0).to_bytes(4, "little"), + [EventType.OK, EventType.ERROR], + ) + + async def reboot(self) -> Event: + logger.debug("Sending reboot command") + return await self.send(b"\x13reboot") + + async def get_bat(self) -> Event: + logger.debug("Getting battery information") + return await self.send(b"\x14", [EventType.BATTERY, EventType.ERROR]) + + async def get_time(self) -> Event: + logger.debug("Getting device time") + return await self.send(b"\x05", [EventType.CURRENT_TIME, EventType.ERROR]) + + async def set_time(self, val: int) -> Event: + logger.debug(f"Setting device time to: {val}") + return await self.send( + b"\x06" + int(val).to_bytes(4, "little"), [EventType.OK, EventType.ERROR] + ) + + async def set_tx_power(self, val: int) -> Event: + logger.debug(f"Setting TX power to: {val}") + return await self.send( + b"\x0c" + int(val).to_bytes(4, "little"), [EventType.OK, EventType.ERROR] + ) + + async def set_radio(self, freq: float, bw: float, sf: int, cr: int) -> Event: + logger.debug(f"Setting radio params: freq={freq}, bw={bw}, sf={sf}, cr={cr}") + return await self.send( + b"\x0b" + + int(float(freq) * 1000).to_bytes(4, "little") + + int(float(bw) * 1000).to_bytes(4, "little") + + int(sf).to_bytes(1, "little") + + int(cr).to_bytes(1, "little"), + [EventType.OK, EventType.ERROR], + ) + + async def set_tuning(self, rx_dly: int, af: int) -> Event: + logger.debug(f"Setting tuning params: rx_dly={rx_dly}, af={af}") + return await self.send( + b"\x15" + + int(rx_dly).to_bytes(4, "little") + + int(af).to_bytes(4, "little") + + int(0).to_bytes(1, "little") + + int(0).to_bytes(1, "little"), + [EventType.OK, EventType.ERROR], + ) + + async def set_other_params( + self, + manual_add_contacts: bool, + telemetry_mode_base: int, + telemetry_mode_loc: int, + telemetry_mode_env: int, + advert_loc_policy: int, + ) -> Event: + telemetry_mode = ( + (telemetry_mode_base & 0b11) + | ((telemetry_mode_loc & 0b11) << 2) + | ((telemetry_mode_env & 0b11) << 4) + ) + data = ( + b"\x26" + + manual_add_contacts.to_bytes(1) + + telemetry_mode.to_bytes(1) + + advert_loc_policy.to_bytes(1) + ) + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event: + infos = (await self.send_appstart()).payload + return await self.set_other_params( + infos["manual_add_contacts"], + telemetry_mode_base, + infos["telemetry_mode_loc"], + infos["telemetry_mode_env"], + infos["adv_loc_policy"], + ) + + async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event: + infos = (await self.send_appstart()).payload + return await self.set_other_params( + infos["manual_add_contacts"], + infos["telemetry_mode_base"], + telemetry_mode_loc, + infos["telemetry_mode_env"], + infos["adv_loc_policy"], + ) + + async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event: + infos = (await self.send_appstart()).payload + return await self.set_other_params( + infos["manual_add_contacts"], + infos["telemetry_mode_base"], + infos["telemetry_mode_loc"], + telemetry_mode_env, + infos["adv_loc_policy"], + ) + + async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event: + infos = (await self.send_appstart()).payload + return await self.set_other_params( + manual_add_contacts, + infos["telemetry_mode_base"], + infos["telemetry_mode_loc"], + infos["telemetry_mode_env"], + infos["adv_loc_policy"], + ) + + async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event: + infos = (await self.send_appstart()).payload + return await self.set_other_params( + infos["manual_add_contacts"], + infos["telemetry_mode_base"], + infos["telemetry_mode_loc"], + infos["telemetry_mode_env"], + advert_loc_policy, + ) + + async def set_devicepin(self, pin: int) -> Event: + logger.debug(f"Setting device PIN to: {pin}") + return await self.send( + b"\x25" + int(pin).to_bytes(4, "little"), [EventType.OK, EventType.ERROR] + ) + + async def get_self_telemetry(self) -> Event: + logger.debug("Getting self telemetry") + data = b"\x27\x00\x00\x00" + return await self.send(data, [EventType.TELEMETRY_RESPONSE, EventType.ERROR]) + + async def get_custom_vars(self) -> Event: + logger.debug("Asking for custom vars") + data = b"\x28" + return await self.send(data, [EventType.CUSTOM_VARS, EventType.ERROR]) + + async def set_custom_var(self, key, value) -> Event: + logger.debug(f"Setting custom var {key} to {value}") + data = b"\x29" + key.encode("utf-8") + b":" + value.encode("utf-8") + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def get_channel(self, channel_idx: int) -> Event: + logger.debug(f"Getting channel info for channel {channel_idx}") + data = b"\x1f" + channel_idx.to_bytes(1, "little") + return await self.send(data, [EventType.CHANNEL_INFO, EventType.ERROR]) + + async def set_channel( + self, channel_idx: int, channel_name: str, channel_secret: bytes + ) -> Event: + logger.debug(f"Setting channel {channel_idx}: name={channel_name}") + + # Pad channel name to 32 bytes + name_bytes = channel_name.encode("utf-8")[:32] + name_bytes = name_bytes.ljust(32, b"\x00") + + # Ensure channel secret is exactly 16 bytes + if len(channel_secret) != 16: + raise ValueError("Channel secret must be exactly 16 bytes") + + data = b"\x20" + channel_idx.to_bytes(1, "little") + name_bytes + channel_secret + return await self.send(data, [EventType.OK, EventType.ERROR]) diff --git a/src/meshcore/commands/messaging.py b/src/meshcore/commands/messaging.py new file mode 100644 index 0000000..898e78f --- /dev/null +++ b/src/meshcore/commands/messaging.py @@ -0,0 +1,167 @@ +import logging +import random +from typing import Optional, Union + +from ..events import Event, EventType +from .base import CommandHandlerBase, DestinationType, _validate_destination + +logger = logging.getLogger("meshcore") + + +class MessagingCommands(CommandHandlerBase): + async def get_msg(self, timeout: Optional[float] = None) -> Event: + logger.debug("Requesting pending messages") + return await self.send( + b"\x0a", + [ + EventType.CONTACT_MSG_RECV, + EventType.CHANNEL_MSG_RECV, + EventType.ERROR, + EventType.NO_MORE_MSGS, + ], + timeout, + ) + + async def send_login(self, dst: DestinationType, pwd: str) -> Event: + dst_bytes = _validate_destination(dst, prefix_length=32) + logger.debug(f"Sending login request to: {dst_bytes.hex()}") + data = b"\x1a" + dst_bytes + pwd.encode("utf-8") + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + + async def send_logout(self, dst: DestinationType) -> Event: + dst_bytes = _validate_destination(dst, prefix_length=32) + data = b"\x1d" + dst_bytes + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def send_statusreq(self, dst: DestinationType) -> Event: + dst_bytes = _validate_destination(dst, prefix_length=32) + logger.debug(f"Sending status request to: {dst_bytes.hex()}") + data = b"\x1b" + dst_bytes + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + + async def send_cmd( + self, dst: DestinationType, cmd: str, timestamp: Optional[int] = None + ) -> Event: + dst_bytes = _validate_destination(dst) + logger.debug(f"Sending command to {dst_bytes.hex()}: {cmd}") + + if timestamp is None: + import time + + timestamp = int(time.time()) + + data = ( + b"\x02\x01\x00" + + timestamp.to_bytes(4, "little") + + dst_bytes + + cmd.encode("utf-8") + ) + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + + async def send_msg( + self, dst: DestinationType, msg: str, timestamp: Optional[int] = None + ) -> Event: + dst_bytes = _validate_destination(dst) + logger.debug(f"Sending message to {dst_bytes.hex()}: {msg}") + + if timestamp is None: + import time + + timestamp = int(time.time()) + + data = ( + b"\x02\x00\x00" + + timestamp.to_bytes(4, "little") + + dst_bytes + + msg.encode("utf-8") + ) + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + + async def send_chan_msg(self, chan, msg, timestamp=None) -> Event: + logger.debug(f"Sending channel message to channel {chan}: {msg}") + + # Default to current time if timestamp not provided + if timestamp is None: + import time + + timestamp = int(time.time()).to_bytes(4, "little") + + data = ( + b"\x03\x00" + chan.to_bytes(1, "little") + timestamp + msg.encode("utf-8") + ) + return await self.send(data, [EventType.OK, EventType.ERROR]) + + async def send_telemetry_req(self, dst: DestinationType) -> Event: + dst_bytes = _validate_destination(dst, prefix_length=32) + logger.debug(f"Asking telemetry to {dst_bytes.hex()}") + data = b"\x27\x00\x00\x00" + dst_bytes + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + + async def send_binary_req(self, dst: DestinationType, bin_data) -> Event: + dst_bytes = _validate_destination(dst, prefix_length=32) + logger.debug(f"Binary request to {dst_bytes.hex()}") + data = b"\x32" + dst_bytes + bin_data + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + + async def send_path_discovery(self, dst: DestinationType) -> Event: + dst_bytes = _validate_destination(dst, prefix_length=32) + logger.debug(f"Path discovery request for {dst_bytes.hex()}") + data = b"\x34\x00" + dst_bytes + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + + async def send_trace( + self, + auth_code: int = 0, + tag: Optional[int] = None, + flags: int = 0, + path: Optional[Union[str, bytes, bytearray]] = None, + ) -> Event: + """ + Send a trace packet to test routing through specific repeaters + + Args: + auth_code: 32-bit authentication code (default: 0) + tag: 32-bit integer to identify this trace (default: random) + flags: 8-bit flags field (default: 0) + path: Optional string with comma-separated hex values representing repeater pubkeys (e.g. "23,5f,3a") + or a bytes/bytearray object with the raw path data + + Returns: + Event object with sent status, tag, and estimated timeout in milliseconds + """ + # Generate random tag if not provided + if tag is None: + tag = random.randint(1, 0xFFFFFFFF) + if auth_code is None: + auth_code = random.randint(1, 0xFFFFFFFF) + + logger.debug( + f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path}" + ) + + # Prepare the command packet: CMD(1) + tag(4) + auth_code(4) + flags(1) + [path] + cmd_data = bytearray([36]) # CMD_SEND_TRACE_PATH + cmd_data.extend(tag.to_bytes(4, "little")) + cmd_data.extend(auth_code.to_bytes(4, "little")) + cmd_data.append(flags) + + # Process path if provided + if path: + if isinstance(path, str): + # Convert comma-separated hex values to bytes + try: + path_bytes = bytearray() + for hex_val in path.split(","): + hex_val = hex_val.strip() + path_bytes.append(int(hex_val, 16)) + cmd_data.extend(path_bytes) + except ValueError as e: + logger.error(f"Invalid path format: {e}") + return Event(EventType.ERROR, {"reason": "invalid_path_format"}) + elif isinstance(path, (bytes, bytearray)): + cmd_data.extend(path) + else: + logger.error(f"Unsupported path type: {type(path)}") + return Event(EventType.ERROR, {"reason": "unsupported_path_type"}) + + return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR]) diff --git a/src/meshcore/connection_manager.py b/src/meshcore/connection_manager.py index fd5984f..c95ec37 100644 --- a/src/meshcore/connection_manager.py +++ b/src/meshcore/connection_manager.py @@ -1,6 +1,7 @@ """ Connection manager that orchestrates reconnection logic for any connection type. """ + import asyncio import logging from typing import Optional, Any, Callable, Protocol @@ -8,21 +9,22 @@ from .events import Event, EventType logger = logging.getLogger("meshcore") + class ConnectionProtocol(Protocol): """Protocol defining the interface that connection classes must implement.""" - + async def connect(self) -> Optional[Any]: """Connect and return connection info, or None if failed.""" ... - + async def disconnect(self): """Disconnect from the device/server.""" ... - + async def send(self, data): """Send data through the connection.""" ... - + def set_reader(self, reader): """Set the message reader.""" ... @@ -30,27 +32,32 @@ class ConnectionProtocol(Protocol): class ConnectionManager: """Manages connection lifecycle with auto-reconnect and event emission.""" - - def __init__(self, connection: ConnectionProtocol, event_dispatcher=None, - auto_reconnect: bool = False, max_reconnect_attempts: int = 3): + + def __init__( + self, + connection: ConnectionProtocol, + event_dispatcher=None, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 3, + ): self.connection = connection self.event_dispatcher = event_dispatcher self.auto_reconnect = auto_reconnect self.max_reconnect_attempts = max_reconnect_attempts - + self._reconnect_attempts = 0 self._is_connected = False self._reconnect_task = None self._disconnect_callback: Optional[Callable] = None - + def set_disconnect_callback(self, callback: Callable): """Set a callback to be called when disconnection is detected.""" self._disconnect_callback = callback - + async def connect(self) -> Optional[Any]: """Connect with event handling and state management.""" result = await self.connection.connect() - + if result is not None: self._is_connected = True self._reconnect_attempts = 0 @@ -58,9 +65,9 @@ class ConnectionManager: logger.debug(f"Connected successfully: {result}") else: logger.debug("Connection failed") - + return result - + async def disconnect(self): """Disconnect with proper cleanup.""" if self._reconnect_task: @@ -70,80 +77,93 @@ class ConnectionManager: except asyncio.CancelledError: pass self._reconnect_task = None - + if self._is_connected: await self.connection.disconnect() self._is_connected = False - await self._emit_event(EventType.DISCONNECTED, {"reason": "manual_disconnect"}) - + await self._emit_event( + EventType.DISCONNECTED, {"reason": "manual_disconnect"} + ) + async def handle_disconnect(self, reason: str = "unknown"): """Handle unexpected disconnections with optional auto-reconnect.""" if not self._is_connected: return - + self._is_connected = False logger.debug(f"Connection lost: {reason}") - - if self.auto_reconnect and self._reconnect_attempts < self.max_reconnect_attempts: + + if ( + self.auto_reconnect + and self._reconnect_attempts < self.max_reconnect_attempts + ): self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) else: - await self._emit_event(EventType.DISCONNECTED, { - "reason": reason, - "reconnect_failed": self._reconnect_attempts >= self.max_reconnect_attempts - }) - + await self._emit_event( + EventType.DISCONNECTED, + { + "reason": reason, + "reconnect_failed": self._reconnect_attempts + >= self.max_reconnect_attempts, + }, + ) + async def _attempt_reconnect(self): """Attempt to reconnect with flat delay.""" - logger.debug(f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})") + logger.debug( + f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})" + ) self._reconnect_attempts += 1 - + # Flat 1 second delay for all attempts await asyncio.sleep(1) - + try: result = await self.connection.connect() if result is not None: self._is_connected = True self._reconnect_attempts = 0 - await self._emit_event(EventType.CONNECTED, { - "connection_info": result, - "reconnected": True - }) - logger.debug(f"Reconnected successfully") + await self._emit_event( + EventType.CONNECTED, + {"connection_info": result, "reconnected": True}, + ) + logger.debug("Reconnected successfully") else: # Reconnection failed, try again if we haven't exceeded max attempts if self._reconnect_attempts < self.max_reconnect_attempts: - self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) + self._reconnect_task = asyncio.create_task( + self._attempt_reconnect() + ) else: - await self._emit_event(EventType.DISCONNECTED, { - "reason": "reconnect_failed", - "max_attempts_exceeded": True - }) + await self._emit_event( + EventType.DISCONNECTED, + {"reason": "reconnect_failed", "max_attempts_exceeded": True}, + ) except Exception as e: logger.debug(f"Reconnection attempt failed: {e}") if self._reconnect_attempts < self.max_reconnect_attempts: self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) else: - await self._emit_event(EventType.DISCONNECTED, { - "reason": f"reconnect_error: {e}", - "max_attempts_exceeded": True - }) - + await self._emit_event( + EventType.DISCONNECTED, + {"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True}, + ) + async def _emit_event(self, event_type: EventType, payload: dict): """Emit connection events if dispatcher is available.""" if self.event_dispatcher: event = Event(event_type, payload) await self.event_dispatcher.dispatch(event) - + @property def is_connected(self) -> bool: """Check if the connection is active.""" return self._is_connected - + async def send(self, data): """Send data through the managed connection.""" return await self.connection.send(data) - + def set_reader(self, reader): """Set the message reader on the underlying connection.""" - self.connection.set_reader(reader) \ No newline at end of file + self.connection.set_reader(reader) diff --git a/src/meshcore/events.py b/src/meshcore/events.py index fc81f80..c21d679 100644 --- a/src/meshcore/events.py +++ b/src/meshcore/events.py @@ -1,13 +1,13 @@ from enum import Enum import logging -from math import log from typing import Any, Dict, Optional, Callable, List, Union import asyncio from dataclasses import dataclass, field logger = logging.getLogger("meshcore") -# Public event types for users to subscribe to + +# Public event types for users to subscribe to class EventType(Enum): CONTACTS = "contacts" SELF_INFO = "self_info" @@ -20,7 +20,7 @@ class EventType(Enum): DEVICE_INFO = "device_info" MSG_SENT = "message_sent" NEW_CONTACT = "new_contact" - + # Push notifications ADVERTISEMENT = "advertisement" PATH_UPDATE = "path_update" @@ -28,7 +28,7 @@ class EventType(Enum): MESSAGES_WAITING = "messages_waiting" RAW_DATA = "raw_data" LOGIN_SUCCESS = "login_success" - LOGIN_FAILED = "login_failed" + LOGIN_FAILED = "login_failed" STATUS_RESPONSE = "status_response" LOG_DATA = "log_data" TRACE_DATA = "trace_data" @@ -38,11 +38,11 @@ class EventType(Enum): CUSTOM_VARS = "custom_vars" CHANNEL_INFO = "channel_info" PATH_RESPONSE = "path_response" - + # Command response types OK = "command_ok" ERROR = "command_error" - + # Connection events CONNECTED = "connected" DISCONNECTED = "disconnected" @@ -53,11 +53,17 @@ class Event: type: EventType payload: Any attributes: Dict[str, Any] = field(default_factory=dict) - - def __init__(self, type: EventType, payload: Any, attributes: Optional[Dict[str, Any]] = None, **kwargs): + + def __init__( + self, + type: EventType, + payload: Any, + attributes: Optional[Dict[str, Any]] = None, + **kwargs, + ): """ Initialize an Event - + Args: type: The event type payload: The event payload @@ -67,18 +73,21 @@ class Event: self.type = type self.payload = payload self.attributes = attributes or {} - + # Add any keyword arguments to the attributes dictionary if kwargs: self.attributes.update(kwargs) + def clone(self): """ Create a copy of the event. - + Returns: A new Event object with the same type, payload, and attributes. """ - copied_payload = self.payload.copy() if isinstance(self.payload, dict) else self.payload + copied_payload = ( + self.payload.copy() if isinstance(self.payload, dict) else self.payload + ) return Event(self.type, copied_payload, self.attributes.copy()) @@ -88,7 +97,7 @@ class Subscription: self.event_type = event_type self.callback = callback self.attribute_filters = attribute_filters or {} - + def unsubscribe(self): self.dispatcher._remove_subscription(self) @@ -99,12 +108,16 @@ class EventDispatcher: self.subscriptions: List[Subscription] = [] self.running = False self._task = None - - def subscribe(self, event_type: Union[EventType, None], callback: Callable[[Event], Union[None, asyncio.Future]], - attribute_filters: Optional[Dict[str, Any]] = None) -> Subscription: + + def subscribe( + self, + event_type: Union[EventType, None], + callback: Callable[[Event], Union[None, asyncio.Future]], + attribute_filters: Optional[Dict[str, Any]] = None, + ) -> Subscription: """ Subscribe to events with optional attribute filtering. - + Parameters: ----------- event_type : EventType or None @@ -113,7 +126,7 @@ class EventDispatcher: Function to call when a matching event is received. attribute_filters : Dict[str, Any], optional Dictionary of attribute key-value pairs that must match for the event to trigger the callback. - + Returns: -------- Subscription object that can be used to unsubscribe. @@ -121,26 +134,36 @@ class EventDispatcher: subscription = Subscription(self, event_type, callback, attribute_filters) self.subscriptions.append(subscription) return subscription - + def _remove_subscription(self, subscription: Subscription): if subscription in self.subscriptions: self.subscriptions.remove(subscription) - + async def dispatch(self, event: Event): await self.queue.put(event) - + async def _process_events(self): while self.running: event = await self.queue.get() - logger.debug(f"Dispatching event: {event.type}, {event.payload}, {event.attributes}") + logger.debug( + f"Dispatching event: {event.type}, {event.payload}, {event.attributes}" + ) for subscription in self.subscriptions.copy(): # Check if event type matches - if subscription.event_type is None or subscription.event_type == event.type: + if ( + subscription.event_type is None + or subscription.event_type == event.type + ): # Check if all attribute filters match - if subscription.attribute_filters and subscription.attribute_filters != {}: + if ( + subscription.attribute_filters + and subscription.attribute_filters != {} + ): # Skip if any filter doesn't match the corresponding event attribute - if not all(event.attributes.get(key) == value - for key, value in subscription.attribute_filters.items()): + if not all( + event.attributes.get(key) == value + for key, value in subscription.attribute_filters.items() + ): continue try: result = subscription.callback(event.clone()) @@ -148,14 +171,14 @@ class EventDispatcher: await result except Exception as e: print(f"Error in event handler: {e}") - + self.queue.task_done() - + async def start(self): if not self.running: self.running = True self._task = asyncio.create_task(self._process_events()) - + async def stop(self): if self.running: self.running = False @@ -167,12 +190,16 @@ class EventDispatcher: except asyncio.CancelledError: pass self._task = None - - async def wait_for_event(self, event_type: EventType, attribute_filters: Optional[Dict[str, Any]] = None, - timeout: float | None = None) -> Optional[Event]: + + async def wait_for_event( + self, + event_type: EventType, + attribute_filters: Optional[Dict[str, Any]] = None, + timeout: float | None = None, + ) -> Optional[Event]: """ Wait for an event of the specified type that matches all attribute filters. - + Parameters: ----------- event_type : EventType @@ -181,19 +208,19 @@ class EventDispatcher: Dictionary of attribute key-value pairs that must match for the event to be returned. timeout : float | None, optional Maximum time to wait for the event, in seconds. - + Returns: -------- The matched event, or None if timeout occurred before a matching event. """ future = asyncio.Future() - + def event_handler(event: Event): if not future.done(): future.set_result(event) - + subscription = self.subscribe(event_type, event_handler, attribute_filters) - + try: return await asyncio.wait_for(future, timeout) except asyncio.TimeoutError: diff --git a/src/meshcore/lpp_json_encoder.py b/src/meshcore/lpp_json_encoder.py index e52c57f..47634a9 100644 --- a/src/meshcore/lpp_json_encoder.py +++ b/src/meshcore/lpp_json_encoder.py @@ -2,60 +2,63 @@ from cayennelpp import LppFrame, LppData from cayennelpp.lpp_type import LppType # Format : type name "how to display value" -# display : None: (use lib default), []: only one value to display, ["field1", "field2" ...]: meaning of each field +# display : None: (use lib default), []: only one value to display, ["field1", "field2" ...]: meaning of each field my_lpp_types = { - 0: ('digital input', []), - 1: ('digital output', []), - 2: ('analog input', []), - 3: ('analog output', []), - 100: ('generic sensor', []), - 101: ('illuminance', []), - 102: ('presence', []), - 103: ('temperature', []), - 104: ('humidity', []), - 113: ('accelerometer', ["acc_x", "acc_y", "acc_z"]), - 115: ('barometer', []), - 116: ('voltage', []), - 117: ('current', []), - 118: ('frequency', []), - 120: ('percentage', []), - 121: ('altitude', []), - 122: ('load', []), - 125: ('concentration', []), - 128: ('power', []), - 130: ('distance', []), - 131: ('energy', []), - 132: ('direction', None), - 133: ('time', []), - 134: ('gyrometer', None), - 135: ('colour', ["red", "green", "blue"]), - 136: ('gps', ["latitude", "longitude", "altitude"]), - 142: ('switch', []), + 0: ("digital input", []), + 1: ("digital output", []), + 2: ("analog input", []), + 3: ("analog output", []), + 100: ("generic sensor", []), + 101: ("illuminance", []), + 102: ("presence", []), + 103: ("temperature", []), + 104: ("humidity", []), + 113: ("accelerometer", ["acc_x", "acc_y", "acc_z"]), + 115: ("barometer", []), + 116: ("voltage", []), + 117: ("current", []), + 118: ("frequency", []), + 120: ("percentage", []), + 121: ("altitude", []), + 122: ("load", []), + 125: ("concentration", []), + 128: ("power", []), + 130: ("distance", []), + 131: ("energy", []), + 132: ("direction", None), + 133: ("time", []), + 134: ("gyrometer", None), + 135: ("colour", ["red", "green", "blue"]), + 136: ("gps", ["latitude", "longitude", "altitude"]), + 142: ("switch", []), } + def lpp_format_val(type, val): - if my_lpp_types[type.type][1] is None : + if my_lpp_types[type.type][1] is None: return val - if len(my_lpp_types[type.type][1]) == 0 : + if len(my_lpp_types[type.type][1]) == 0: return val[0] val_dict = {} i = 0 - for t in my_lpp_types[type.type][1] : + for t in my_lpp_types[type.type][1]: val_dict[t] = val[i] i = i + 1 return val_dict -def lpp_json_encoder (obj, types = my_lpp_types) : + +def lpp_json_encoder(obj, types=my_lpp_types): """Encode LppType, LppData, and LppFrame to JSON.""" if isinstance(obj, LppFrame): return obj.data if isinstance(obj, LppType): return my_lpp_types[obj.type][0] if isinstance(obj, LppData): - return {"channel" : obj.channel, - "type" : obj.type, - "value" : lpp_format_val(obj.type, obj.value) + return { + "channel": obj.channel, + "type": obj.type, + "value": lpp_format_val(obj.type, obj.value), } raise TypeError(repr(obj) + " is not JSON serialized") diff --git a/src/meshcore/meshcore.py b/src/meshcore/meshcore.py index 699cc92..5b21871 100644 --- a/src/meshcore/meshcore.py +++ b/src/meshcore/meshcore.py @@ -1,8 +1,8 @@ import asyncio import logging -from typing import Optional, Dict, Any, Union +from typing import Any, Callable, Coroutine, Dict, Optional, Union -from .events import EventDispatcher, EventType +from .events import Event, EventDispatcher, EventType, Subscription from .reader import MessageReader from .commands import CommandHandler from .connection_manager import ConnectionManager @@ -13,21 +13,31 @@ from .serial_cx import SerialConnection # Setup default logger logger = logging.getLogger("meshcore") + class MeshCore: """ Interface to a MeshCore device """ - def __init__(self, cx, debug=False, only_error=False, default_timeout=None, auto_reconnect=False, max_reconnect_attempts=3): + + def __init__( + self, + cx: Union[BLEConnection, TCPConnection, SerialConnection], + debug: bool = False, + only_error: bool = False, + default_timeout: Optional[float] = None, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 3, + ): # Wrap connection with ConnectionManager self.dispatcher = EventDispatcher() self.connection_manager = ConnectionManager( cx, self.dispatcher, auto_reconnect, max_reconnect_attempts ) self.cx = self.connection_manager # For backward compatibility - + self._reader = MessageReader(self.dispatcher) self.commands = CommandHandler(default_timeout=default_timeout) - + # Set up logger if debug: logger.setLevel(logging.DEBUG) @@ -35,14 +45,14 @@ class MeshCore: logger.setLevel(logging.ERROR) else: logger.setLevel(logging.INFO) - + # Set up connections self.commands.set_connection(self.connection_manager) - + # Set the dispatcher in the command handler self.commands.set_dispatcher(self.dispatcher) self.commands.set_reader(self._reader) - + # Initialize state (private) self._contacts = {} self._contacts_dirty = True @@ -51,40 +61,77 @@ class MeshCore: self._time = 0 self._lastmod = 0 self._auto_update_contacts = False - + # Set up event subscriptions to track data self._setup_data_tracking() - + self.connection_manager.set_reader(self._reader) - + # Set up disconnect callback cx.set_disconnect_callback(self.connection_manager.handle_disconnect) - + @classmethod - async def create_tcp(cls, host: str, port: int, debug: bool = False, only_error:bool = False, default_timeout=None, - auto_reconnect: bool = False, max_reconnect_attempts: int = 3) -> 'MeshCore': - """Create and connect a MeshCore instance using TCP connection""" + async def create_tcp( + cls, + host: str, + port: int, + debug: bool = False, + only_error: bool = False, + default_timeout=None, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 3, + ) -> "MeshCore": + """Create and connect a MeshCore instance using TCP connection""" connection = TCPConnection(host, port) - - mc = cls(connection, debug=debug, only_error=only_error, default_timeout=default_timeout, - auto_reconnect=auto_reconnect, max_reconnect_attempts=max_reconnect_attempts) + + mc = cls( + connection, + debug=debug, + only_error=only_error, + default_timeout=default_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + ) await mc.connect() return mc - + @classmethod - async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False, only_error:bool=False, default_timeout=None, - auto_reconnect: bool = False, max_reconnect_attempts: int = 3, cx_dly:float = 0.1) -> 'MeshCore': + async def create_serial( + cls, + port: str, + baudrate: int = 115200, + debug: bool = False, + only_error: bool = False, + default_timeout=None, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 3, + cx_dly: float = 0.1, + ) -> "MeshCore": """Create and connect a MeshCore instance using serial connection""" connection = SerialConnection(port, baudrate, cx_dly=cx_dly) - - mc = cls(connection, debug=debug, only_error=only_error, default_timeout=default_timeout, - auto_reconnect=auto_reconnect, max_reconnect_attempts=max_reconnect_attempts) + + mc = cls( + connection, + debug=debug, + only_error=only_error, + default_timeout=default_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + ) await mc.connect() return mc - + @classmethod - async def create_ble(cls, address: Optional[str] = None, client=None, debug: bool = False, only_error:bool=False, default_timeout=None, - auto_reconnect: bool = False, max_reconnect_attempts: int = 3) -> 'MeshCore': + async def create_ble( + cls, + address: Optional[str] = None, + client=None, + debug: bool = False, + only_error: bool = False, + default_timeout=None, + auto_reconnect: bool = False, + max_reconnect_attempts: int = 3, + ) -> "MeshCore": """ Create and connect a MeshCore instance using BLE connection. @@ -94,87 +141,103 @@ class MeshCore: If provided, 'address' is ignored for connection but can be used for identification. """ - + connection = BLEConnection(address=address, client=client) - - mc = cls(connection, debug=debug, only_error=only_error, default_timeout=default_timeout, - auto_reconnect=auto_reconnect, max_reconnect_attempts=max_reconnect_attempts) + + mc = cls( + connection, + debug=debug, + only_error=only_error, + default_timeout=default_timeout, + auto_reconnect=auto_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + ) await mc.connect() return mc - + async def connect(self): await self.dispatcher.start() result = await self.connection_manager.connect() if result is None: raise ConnectionError("Failed to connect to device") return await self.commands.send_appstart() - + async def disconnect(self): """Disconnect from the device and clean up resources.""" # First stop the dispatcher to prevent any new events await self.dispatcher.stop() - + # Stop auto message fetching if it's running - if hasattr(self, '_auto_fetch_subscription') and self._auto_fetch_subscription: + if hasattr(self, "_auto_fetch_subscription") and self._auto_fetch_subscription: await self.stop_auto_message_fetching() - + # Disconnect the connection object await self.connection_manager.disconnect() - + def stop(self): """Synchronously stop the event dispatcher task""" if self.dispatcher._task and not self.dispatcher._task.done(): self.dispatcher.running = False self.dispatcher._task.cancel() - - def subscribe(self, event_type: Union[EventType, None], callback, attribute_filters: Optional[Dict[str, Any]] = None): + + def subscribe( + self, + event_type: Union[EventType, None], + callback: Callable[[Event], Coroutine[Any, Any, None]], + attribute_filters: Optional[Dict[str, Any]] = None, + ) -> Subscription: """ Subscribe to events using EventType enum with optional attribute filtering - + Args: event_type: Type of event to subscribe to, from EventType enum callback: Async function to call when event occurs attribute_filters: Dictionary of attribute key-value pairs that must match for the event to trigger the callback - + Returns: Subscription object that can be used to unsubscribe - + Example: # Subscribe to ACK events where the 'code' attribute has a specific value mc.subscribe( - EventType.ACK, + EventType.ACK, my_callback_function, attribute_filters={'code': 'SUCCESS'} ) """ return self.dispatcher.subscribe(event_type, callback, attribute_filters) - - def unsubscribe(self, subscription): + + def unsubscribe(self, subscription: Subscription) -> None: """ Unsubscribe from events using a subscription object - + Args: subscription: Subscription object returned from subscribe() """ if subscription: subscription.unsubscribe() - - async def wait_for_event(self, event_type: EventType, attribute_filters: Optional[Dict[str, Any]] = None, timeout=None): + + async def wait_for_event( + self, + event_type: EventType, + attribute_filters: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + ) -> Optional[Event]: """ Wait for an event using EventType enum with optional attribute filtering - + Args: event_type: Type of event to wait for, from EventType enum attribute_filters: Dictionary of attribute key-value pairs to match against the event timeout: Maximum time to wait in seconds, or None to use default_timeout - + Returns: Event object or None if timeout - + Example: # Wait for an ACK event where the 'code' attribute has a specific value await mc.wait_for_event( - EventType.ACK, + EventType.ACK, attribute_filters={'code': 'SUCCESS'}, timeout=30.0 ) @@ -182,22 +245,25 @@ class MeshCore: # Use the provided timeout or fall back to default_timeout if timeout is None: timeout = self.default_timeout - - return await self.dispatcher.wait_for_event(event_type, attribute_filters, timeout) - + + return await self.dispatcher.wait_for_event( + event_type, attribute_filters, timeout + ) + def _setup_data_tracking(self): """Set up event subscriptions to track data internally""" + async def _update_contacts(event): - #self._contacts.update(event.payload) + # self._contacts.update(event.payload) for c in event.payload.values(): if c["public_key"] in self._contacts: self._contacts[c["public_key"]].update(c) else: - self._contacts[c["public_key"]]=c - if "lastmod" in event.attributes : - self._lastmod = event.attributes['lastmod'] + self._contacts[c["public_key"]] = c + if "lastmod" in event.attributes: + self._lastmod = event.attributes["lastmod"] self._contacts_dirty = False - + async def _add_pending_contact(event): c = event.payload self._pending_contacts[c["public_key"]] = c @@ -206,13 +272,13 @@ class MeshCore: self._contacts_dirty = True if self._auto_update_contacts: await self.ensure_contacts(follow=True) - + async def _update_self_info(event): self._self_info = event.payload - + async def _update_time(event): self._time = event.payload.get("time", 0) - + # Subscribe to events to update internal state self.subscribe(EventType.CONTACTS, _update_contacts) self.subscribe(EventType.NEW_CONTACT, _add_pending_contact) @@ -220,166 +286,177 @@ class MeshCore: self.subscribe(EventType.CURRENT_TIME, _update_time) self.subscribe(EventType.ADVERTISEMENT, _contact_change) self.subscribe(EventType.PATH_UPDATE, _contact_change) - + # Getter methods for state @property - def contacts(self): + def contacts(self) -> Dict[str, Any]: """Get the current contacts""" return self._contacts @property - def contacts_dirty(self): + def contacts_dirty(self) -> bool: """Get wether contact list is in sync""" return self._contacts_dirty - + @property - def auto_update_contacts(self): + def auto_update_contacts(self) -> bool: """Get wether contact list is in sync""" return self._auto_update_contacts @auto_update_contacts.setter - def auto_update_contacts(self, value): + def auto_update_contacts(self, value: bool) -> None: self._auto_update_contacts = value @property - def self_info(self): + def self_info(self) -> Dict[str, Any]: """Get device self info""" return self._self_info - + @property - def time(self): + def time(self) -> int: """Get the current device time""" return self._time - + @property - def is_connected(self): + def is_connected(self) -> bool: """Check if the connection is active""" return self.connection_manager.is_connected - + @property - def default_timeout(self): + def default_timeout(self) -> float: """Get the default timeout for commands""" return self.commands.default_timeout - + @default_timeout.setter - def default_timeout(self, value): + def default_timeout(self, value: float) -> None: """Set the default timeout for commands""" self.commands.default_timeout = value - + @property - def pending_contacts(self): + def pending_contacts(self) -> Dict[str, Any]: """Get pending contacts""" return self._pending_contacts - - def pop_pending_contact(self, key): + + def pop_pending_contact(self, key: str) -> Optional[Dict[str, Any]]: return self._pending_contacts.pop(key, None) - def flush_pending_contacts(self): # would be interesting to have a time param + def flush_pending_contacts(self) -> None: # would be interesting to have a time param self._pending_contacts = {} - def get_contact_by_name(self, name) -> Optional[Dict[str, Any]]: + def get_contact_by_name(self, name: str) -> Optional[Dict[str, Any]]: """ Find a contact by its name (adv_name field) - + Args: name: The name to search for - + Returns: Contact dictionary or None if not found """ if not self._contacts: return None - + for _, contact in self._contacts.items(): if contact.get("adv_name", "").lower() == name.lower(): return contact - + return None - - def get_contact_by_key_prefix(self, prefix) -> Optional[Dict[str, Any]]: + + def get_contact_by_key_prefix(self, prefix: str) -> Optional[Dict[str, Any]]: """ Find a contact by its public key prefix - + Args: prefix: The public key prefix to search for (can be a partial prefix) - + Returns: Contact dictionary or None if not found """ if not self._contacts or not prefix: return None - + # Convert the prefix to lowercase for case-insensitive matching prefix = prefix.lower() - + for contact_id, contact in self._contacts.items(): public_key = contact.get("public_key", "").lower() if public_key.startswith(prefix): return contact - + return None - - async def start_auto_message_fetching(self): + + async def start_auto_message_fetching(self) -> Subscription: """ Start automatically fetching messages when messages_waiting events are received. - This will continuously check for new messages when the device indicates + This will continuously check for new messages when the device indicates messages are waiting. """ self._auto_fetch_task = None self._auto_fetch_running = True - + async def _handle_messages_waiting(event): # Only start a new fetch task if one isn't already running if not self._auto_fetch_task or self._auto_fetch_task.done(): self._auto_fetch_task = asyncio.create_task(_fetch_messages_loop()) - + async def _fetch_messages_loop(): while self._auto_fetch_running: try: # Request the next message result = await self.commands.get_msg() - + # If we got a NO_MORE_MSGS event or an error, stop fetching - if result.type == EventType.NO_MORE_MSGS or result.type == EventType.ERROR: - logger.debug("No more messages or error occurred, stopping auto-fetch.") + if ( + result.type == EventType.NO_MORE_MSGS + or result.type == EventType.ERROR + ): + logger.debug( + "No more messages or error occurred, stopping auto-fetch." + ) break - + # Small delay to prevent overwhelming the device await asyncio.sleep(0.1) except Exception as e: logger.error(f"Error fetching messages: {e}") break - + # Subscribe to MESSAGES_WAITING events - self._auto_fetch_subscription = self.subscribe(EventType.MESSAGES_WAITING, _handle_messages_waiting) - + self._auto_fetch_subscription = self.subscribe( + EventType.MESSAGES_WAITING, _handle_messages_waiting + ) + # Check for any pending messages immediately await self.commands.get_msg() - + return self._auto_fetch_subscription - + async def stop_auto_message_fetching(self): """ Stop automatically fetching messages when messages_waiting events are received. """ - if hasattr(self, '_auto_fetch_subscription') and self._auto_fetch_subscription: + if hasattr(self, "_auto_fetch_subscription") and self._auto_fetch_subscription: self.unsubscribe(self._auto_fetch_subscription) self._auto_fetch_subscription = None - - if hasattr(self, '_auto_fetch_running'): + + if hasattr(self, "_auto_fetch_running"): self._auto_fetch_running = False - - if hasattr(self, '_auto_fetch_task') and self._auto_fetch_task and not self._auto_fetch_task.done(): + + if ( + hasattr(self, "_auto_fetch_task") + and self._auto_fetch_task + and not self._auto_fetch_task.done() + ): self._auto_fetch_task.cancel() try: - await self._auto_fetch_task # type: ignore + await self._auto_fetch_task # type: ignore except asyncio.CancelledError: pass self._auto_fetch_task = None - - async def ensure_contacts(self, follow=False): + + async def ensure_contacts(self, follow: bool = False) -> bool: """Ensure contacts are fetched""" - if not self._contacts or (follow and self._contacts_dirty) : - await self.commands.get_contacts(lastmod = self._lastmod) + if not self._contacts or (follow and self._contacts_dirty): + await self.commands.get_contacts(lastmod=self._lastmod) return True return False diff --git a/src/meshcore/packets.py b/src/meshcore/packets.py index f8a9e0e..00b5b27 100644 --- a/src/meshcore/packets.py +++ b/src/meshcore/packets.py @@ -1,5 +1,6 @@ from enum import Enum + # Packet prefixes for the protocol class PacketType(Enum): OK = 0 @@ -26,7 +27,7 @@ class PacketType(Enum): CUSTOM_VARS = 21 BINARY_REQ = 50 FACTORY_RESET = 51 - + # Push notifications ADVERTISEMENT = 0x80 PATH_UPDATE = 0x81 diff --git a/src/meshcore/reader.py b/src/meshcore/reader.py index 81713e9..9155784 100644 --- a/src/meshcore/reader.py +++ b/src/meshcore/reader.py @@ -1,8 +1,6 @@ -import sys import logging -import asyncio import json -from typing import Any, Optional, Dict +from typing import Any, Dict from .events import Event, EventType, EventDispatcher from .packets import PacketType from cayennelpp import LppFrame, LppData @@ -18,35 +16,37 @@ class MessageReader: # before events are dispatched self.contacts = {} # Temporary storage during contact list building self.contact_nb = 0 # Used for contact processing - + async def handle_rx(self, data: bytearray): packet_type_value = data[0] logger.debug(f"Received data: {data.hex()}") - + # Handle command responses if packet_type_value == PacketType.OK.value: result: Dict[str, Any] = {} if len(data) == 5: - result["value"] = int.from_bytes(data[1:5], byteorder='little') - + result["value"] = int.from_bytes(data[1:5], byteorder="little") + # Dispatch event for the OK response await self.dispatcher.dispatch(Event(EventType.OK, result)) - + elif packet_type_value == PacketType.ERROR.value: if len(data) > 1: result = {"error_code": data[1]} else: result = {} - + # Dispatch event for the ERROR response await self.dispatcher.dispatch(Event(EventType.ERROR, result)) - + elif packet_type_value == PacketType.CONTACT_START.value: - self.contact_nb = int.from_bytes(data[1:5], byteorder='little') + self.contact_nb = int.from_bytes(data[1:5], byteorder="little") self.contacts = {} - - elif packet_type_value == PacketType.CONTACT.value or\ - packet_type_value == PacketType.PUSH_CODE_NEW_ADVERT.value: + + elif ( + packet_type_value == PacketType.CONTACT.value + or packet_type_value == PacketType.PUSH_CODE_NEW_ADVERT.value + ): c = {} c["public_key"] = data[1:33].hex() c["type"] = data[33] @@ -55,195 +55,215 @@ class MessageReader: plen = int.from_bytes(data[35:36], signed=True) if plen == -1: plen = 0 - c["out_path"] = data[36:36+plen].hex() - c["adv_name"] = data[100:132].decode('utf-8', 'ignore').replace("\0","") - c["last_advert"] = int.from_bytes(data[132:136], byteorder='little') - c["adv_lat"] = int.from_bytes(data[136:140], byteorder='little',signed=True)/1e6 - c["adv_lon"] = int.from_bytes(data[140:144], byteorder='little',signed=True)/1e6 - c["lastmod"] = int.from_bytes(data[144:148], byteorder='little') + c["out_path"] = data[36 : 36 + plen].hex() + c["adv_name"] = data[100:132].decode("utf-8", "ignore").replace("\0", "") + c["last_advert"] = int.from_bytes(data[132:136], byteorder="little") + c["adv_lat"] = ( + int.from_bytes(data[136:140], byteorder="little", signed=True) / 1e6 + ) + c["adv_lon"] = ( + int.from_bytes(data[140:144], byteorder="little", signed=True) / 1e6 + ) + c["lastmod"] = int.from_bytes(data[144:148], byteorder="little") - if packet_type_value == PacketType.PUSH_CODE_NEW_ADVERT.value : + if packet_type_value == PacketType.PUSH_CODE_NEW_ADVERT.value: await self.dispatcher.dispatch(Event(EventType.NEW_CONTACT, c)) else: self.contacts[c["public_key"]] = c - + elif packet_type_value == PacketType.CONTACT_END.value: - lastmod = int.from_bytes(data[1:5], byteorder='little') + lastmod = int.from_bytes(data[1:5], byteorder="little") attributes = { "lastmod": lastmod, } - await self.dispatcher.dispatch(Event(EventType.CONTACTS, self.contacts, attributes)) - + await self.dispatcher.dispatch( + Event(EventType.CONTACTS, self.contacts, attributes) + ) + elif packet_type_value == PacketType.SELF_INFO.value: self_info = {} self_info["adv_type"] = data[1] self_info["tx_power"] = data[2] self_info["max_tx_power"] = data[3] self_info["public_key"] = data[4:36].hex() - self_info["adv_lat"] = int.from_bytes(data[36:40], byteorder='little', signed=True)/1e6 - self_info["adv_lon"] = int.from_bytes(data[40:44], byteorder='little', signed=True)/1e6 + self_info["adv_lat"] = ( + int.from_bytes(data[36:40], byteorder="little", signed=True) / 1e6 + ) + self_info["adv_lon"] = ( + int.from_bytes(data[40:44], byteorder="little", signed=True) / 1e6 + ) self_info["adv_loc_policy"] = data[45] self_info["telemetry_mode_env"] = (data[46] >> 4) & 0b11 self_info["telemetry_mode_loc"] = (data[46] >> 2) & 0b11 self_info["telemetry_mode_base"] = (data[46]) & 0b11 self_info["manual_add_contacts"] = data[47] > 0 - self_info["radio_freq"] = int.from_bytes(data[48:52], byteorder='little') / 1000 - self_info["radio_bw"] = int.from_bytes(data[52:56], byteorder='little') / 1000 + self_info["radio_freq"] = ( + int.from_bytes(data[48:52], byteorder="little") / 1000 + ) + self_info["radio_bw"] = ( + int.from_bytes(data[52:56], byteorder="little") / 1000 + ) self_info["radio_sf"] = data[56] self_info["radio_cr"] = data[57] - self_info["name"] = data[58:].decode('utf-8', 'ignore') + self_info["name"] = data[58:].decode("utf-8", "ignore") await self.dispatcher.dispatch(Event(EventType.SELF_INFO, self_info)) - + elif packet_type_value == PacketType.MSG_SENT.value: res = {} res["type"] = data[1] res["expected_ack"] = bytes(data[2:6]) - res["suggested_timeout"] = int.from_bytes(data[6:10], byteorder='little') - + res["suggested_timeout"] = int.from_bytes(data[6:10], byteorder="little") + attributes = { "type": res["type"], - "expected_ack": res["expected_ack"].hex() + "expected_ack": res["expected_ack"].hex(), } - + await self.dispatcher.dispatch(Event(EventType.MSG_SENT, res, attributes)) - + elif packet_type_value == PacketType.CONTACT_MSG_RECV.value: res = {} res["type"] = "PRIV" res["pubkey_prefix"] = data[1:7].hex() res["path_len"] = data[7] res["txt_type"] = data[8] - res["sender_timestamp"] = int.from_bytes(data[9:13], byteorder='little') + res["sender_timestamp"] = int.from_bytes(data[9:13], byteorder="little") if data[8] == 2: res["signature"] = data[13:17].hex() - res["text"] = data[17:].decode('utf-8', 'ignore') + res["text"] = data[17:].decode("utf-8", "ignore") else: - res["text"] = data[13:].decode('utf-8', 'ignore') - + res["text"] = data[13:].decode("utf-8", "ignore") + attributes = { "pubkey_prefix": res["pubkey_prefix"], - "txt_type": res["txt_type"] + "txt_type": res["txt_type"], } - + evt_type = EventType.CONTACT_MSG_RECV await self.dispatcher.dispatch(Event(evt_type, res, attributes)) - + elif packet_type_value == 16: # A reply to CMD_SYNC_NEXT_MESSAGE (ver >= 3) res = {} res["type"] = "PRIV" - res["SNR"] = int.from_bytes(data[1:2], byteorder='little', signed=True) / 4 + res["SNR"] = int.from_bytes(data[1:2], byteorder="little", signed=True) / 4 res["pubkey_prefix"] = data[4:10].hex() res["path_len"] = data[10] res["txt_type"] = data[11] - res["sender_timestamp"] = int.from_bytes(data[12:16], byteorder='little') + res["sender_timestamp"] = int.from_bytes(data[12:16], byteorder="little") if data[11] == 2: res["signature"] = data[16:20].hex() - res["text"] = data[20:].decode('utf-8', 'ignore') + res["text"] = data[20:].decode("utf-8", "ignore") else: - res["text"] = data[16:].decode('utf-8', 'ignore') - + res["text"] = data[16:].decode("utf-8", "ignore") + attributes = { "pubkey_prefix": res["pubkey_prefix"], - "txt_type": res["txt_type"] + "txt_type": res["txt_type"], } - - await self.dispatcher.dispatch(Event(EventType.CONTACT_MSG_RECV, res, attributes)) - + + await self.dispatcher.dispatch( + Event(EventType.CONTACT_MSG_RECV, res, attributes) + ) + elif packet_type_value == PacketType.CHANNEL_MSG_RECV.value: res = {} res["type"] = "CHAN" res["channel_idx"] = data[1] res["path_len"] = data[2] res["txt_type"] = data[3] - res["sender_timestamp"] = int.from_bytes(data[4:8], byteorder='little') - res["text"] = data[8:].decode('utf-8', 'ignore') - + res["sender_timestamp"] = int.from_bytes(data[4:8], byteorder="little") + res["text"] = data[8:].decode("utf-8", "ignore") + attributes = { "channel_idx": res["channel_idx"], - "txt_type": res["txt_type"] + "txt_type": res["txt_type"], } - - await self.dispatcher.dispatch(Event(EventType.CHANNEL_MSG_RECV, res, attributes)) - + + await self.dispatcher.dispatch( + Event(EventType.CHANNEL_MSG_RECV, res, attributes) + ) + elif packet_type_value == 17: # A reply to CMD_SYNC_NEXT_MESSAGE (ver >= 3) res = {} res["type"] = "CHAN" - res["SNR"] = int.from_bytes(data[1:2], byteorder='little', signed=True) / 4 + res["SNR"] = int.from_bytes(data[1:2], byteorder="little", signed=True) / 4 res["channel_idx"] = data[4] res["path_len"] = data[5] res["txt_type"] = data[6] - res["sender_timestamp"] = int.from_bytes(data[7:11], byteorder='little') - res["text"] = data[11:].decode('utf-8', 'ignore') - + res["sender_timestamp"] = int.from_bytes(data[7:11], byteorder="little") + res["text"] = data[11:].decode("utf-8", "ignore") + attributes = { "channel_idx": res["channel_idx"], - "txt_type": res["txt_type"] + "txt_type": res["txt_type"], } - - await self.dispatcher.dispatch(Event(EventType.CHANNEL_MSG_RECV, res, attributes)) - + + await self.dispatcher.dispatch( + Event(EventType.CHANNEL_MSG_RECV, res, attributes) + ) + elif packet_type_value == PacketType.CURRENT_TIME.value: - time_value = int.from_bytes(data[1:5], byteorder='little') + time_value = int.from_bytes(data[1:5], byteorder="little") result = {"time": time_value} await self.dispatcher.dispatch(Event(EventType.CURRENT_TIME, result)) - + elif packet_type_value == PacketType.NO_MORE_MSGS.value: result = {"messages_available": False} await self.dispatcher.dispatch(Event(EventType.NO_MORE_MSGS, result)) - + elif packet_type_value == PacketType.CONTACT_URI.value: contact_uri = "meshcore://" + data[1:].hex() result = {"uri": contact_uri} await self.dispatcher.dispatch(Event(EventType.CONTACT_URI, result)) - + elif packet_type_value == PacketType.BATTERY.value: - battery_level = int.from_bytes(data[1:3], byteorder='little') + battery_level = int.from_bytes(data[1:3], byteorder="little") result = {"level": battery_level} - if len(data) > 3 : # has storage info as well - result["used_kb"] = int.from_bytes(data[3:7], byteorder='little') - result["total_kb"] = int.from_bytes(data[7:11], byteorder='little') + if len(data) > 3: # has storage info as well + result["used_kb"] = int.from_bytes(data[3:7], byteorder="little") + result["total_kb"] = int.from_bytes(data[7:11], byteorder="little") await self.dispatcher.dispatch(Event(EventType.BATTERY, result)) - + elif packet_type_value == PacketType.DEVICE_INFO.value: res = {} res["fw ver"] = data[1] if data[1] >= 3: res["max_contacts"] = data[2] * 2 res["max_channels"] = data[3] - res["ble_pin"] = int.from_bytes(data[4:8], byteorder='little') - res["fw_build"] = data[8:20].decode('utf-8', 'ignore').replace("\0","") - res["model"] = data[20:60].decode('utf-8', 'ignore').replace("\0","") - res["ver"] = data[60:80].decode('utf-8', 'ignore').replace("\0","") + res["ble_pin"] = int.from_bytes(data[4:8], byteorder="little") + res["fw_build"] = data[8:20].decode("utf-8", "ignore").replace("\0", "") + res["model"] = data[20:60].decode("utf-8", "ignore").replace("\0", "") + res["ver"] = data[60:80].decode("utf-8", "ignore").replace("\0", "") await self.dispatcher.dispatch(Event(EventType.DEVICE_INFO, res)) - + elif packet_type_value == PacketType.CUSTOM_VARS.value: logger.debug(f"received custom vars response: {data.hex()}") res = {} - rawdata = data[1:].decode('utf-8', 'ignore') - if not rawdata == "" : + rawdata = data[1:].decode("utf-8", "ignore") + if not rawdata == "": pairs = rawdata.split(",") - for p in pairs : + for p in pairs: psplit = p.split(":") res[psplit[0]] = psplit[1] logger.debug(f"got custom vars : {res}") await self.dispatcher.dispatch(Event(EventType.CUSTOM_VARS, res)) - + elif packet_type_value == PacketType.CHANNEL_INFO.value: logger.debug(f"received channel info response: {data.hex()}") res = {} res["channel_idx"] = data[1] - + # Channel name is null-terminated, so find the first null byte name_bytes = data[2:34] null_pos = name_bytes.find(0) if null_pos >= 0: - res["channel_name"] = name_bytes[:null_pos].decode('utf-8', 'ignore') + res["channel_name"] = name_bytes[:null_pos].decode("utf-8", "ignore") else: - res["channel_name"] = name_bytes.decode('utf-8', 'ignore') - + res["channel_name"] = name_bytes.decode("utf-8", "ignore") + res["channel_secret"] = data[34:50] - await self.dispatcher.dispatch(Event(EventType.CHANNEL_INFO, res, res)) + await self.dispatcher.dispatch(Event(EventType.CHANNEL_INFO, res, res)) # Push notifications elif packet_type_value == PacketType.ADVERTISEMENT.value: @@ -251,30 +271,28 @@ class MessageReader: res = {} res["public_key"] = data[1:33].hex() await self.dispatcher.dispatch(Event(EventType.ADVERTISEMENT, res, res)) - + elif packet_type_value == PacketType.PATH_UPDATE.value: logger.debug("Code path update") res = {} res["public_key"] = data[1:33].hex() await self.dispatcher.dispatch(Event(EventType.PATH_UPDATE, res, res)) - + elif packet_type_value == PacketType.ACK.value: logger.debug("Received ACK") ack_data = {} - + if len(data) >= 5: ack_data["code"] = bytes(data[1:5]).hex() - - attributes = { - "code": ack_data.get("code", "") - } - + + attributes = {"code": ack_data.get("code", "")} + await self.dispatcher.dispatch(Event(EventType.ACK, ack_data, attributes)) - + elif packet_type_value == PacketType.MESSAGES_WAITING.value: logger.debug("Msgs are waiting") await self.dispatcher.dispatch(Event(EventType.MESSAGES_WAITING, {})) - + elif packet_type_value == PacketType.RAW_DATA.value: res = {} res["SNR"] = data[1] / 4 @@ -283,145 +301,161 @@ class MessageReader: logger.debug("Received raw data") print(res) await self.dispatcher.dispatch(Event(EventType.RAW_DATA, res)) - + elif packet_type_value == PacketType.LOGIN_SUCCESS.value: res = {} if len(data) > 1: res["permissions"] = data[1] res["is_admin"] = (data[1] & 1) == 1 # Check if admin bit is set - + if len(data) > 7: res["pubkey_prefix"] = data[2:8].hex() - - attributes = { - "pubkey_prefix": res.get("pubkey_prefix") - } - - await self.dispatcher.dispatch(Event(EventType.LOGIN_SUCCESS, res, attributes)) - + + attributes = {"pubkey_prefix": res.get("pubkey_prefix")} + + await self.dispatcher.dispatch( + Event(EventType.LOGIN_SUCCESS, res, attributes) + ) + elif packet_type_value == PacketType.LOGIN_FAILED.value: res = {} - + if len(data) > 7: res["pubkey_prefix"] = data[2:8].hex() - - attributes = { - "pubkey_prefix": res.get("pubkey_prefix") - } - - await self.dispatcher.dispatch(Event(EventType.LOGIN_FAILED, res, attributes)) - + + attributes = {"pubkey_prefix": res.get("pubkey_prefix")} + + await self.dispatcher.dispatch( + Event(EventType.LOGIN_FAILED, res, attributes) + ) + elif packet_type_value == PacketType.STATUS_RESPONSE.value: res = {} res["pubkey_pre"] = data[2:8].hex() - res["bat"] = int.from_bytes(data[8:10], byteorder='little') - res["tx_queue_len"] = int.from_bytes(data[10:12], byteorder='little') - res["noise_floor"] = int.from_bytes(data[12:14], byteorder='little', signed=True) - res["last_rssi"] = int.from_bytes(data[14:16], byteorder='little', signed=True) - res["nb_recv"] = int.from_bytes(data[16:20], byteorder='little', signed=False) - res["nb_sent"] = int.from_bytes(data[20:24], byteorder='little', signed=False) - res["airtime"] = int.from_bytes(data[24:28], byteorder='little') - res["uptime"] = int.from_bytes(data[28:32], byteorder='little') - res["sent_flood"] = int.from_bytes(data[32:36], byteorder='little') - res["sent_direct"] = int.from_bytes(data[36:40], byteorder='little') - res["recv_flood"] = int.from_bytes(data[40:44], byteorder='little') - res["recv_direct"] = int.from_bytes(data[44:48], byteorder='little') - res["full_evts"] = int.from_bytes(data[48:50], byteorder='little') - res["last_snr"] = int.from_bytes(data[50:52], byteorder='little', signed=True) / 4 - res["direct_dups"] = int.from_bytes(data[52:54], byteorder='little') - res["flood_dups"] = int.from_bytes(data[54:56], byteorder='little') - res["rx_airtime"] = int.from_bytes(data[56:60], byteorder='little') + res["bat"] = int.from_bytes(data[8:10], byteorder="little") + res["tx_queue_len"] = int.from_bytes(data[10:12], byteorder="little") + res["noise_floor"] = int.from_bytes( + data[12:14], byteorder="little", signed=True + ) + res["last_rssi"] = int.from_bytes( + data[14:16], byteorder="little", signed=True + ) + res["nb_recv"] = int.from_bytes( + data[16:20], byteorder="little", signed=False + ) + res["nb_sent"] = int.from_bytes( + data[20:24], byteorder="little", signed=False + ) + res["airtime"] = int.from_bytes(data[24:28], byteorder="little") + res["uptime"] = int.from_bytes(data[28:32], byteorder="little") + res["sent_flood"] = int.from_bytes(data[32:36], byteorder="little") + res["sent_direct"] = int.from_bytes(data[36:40], byteorder="little") + res["recv_flood"] = int.from_bytes(data[40:44], byteorder="little") + res["recv_direct"] = int.from_bytes(data[44:48], byteorder="little") + res["full_evts"] = int.from_bytes(data[48:50], byteorder="little") + res["last_snr"] = ( + int.from_bytes(data[50:52], byteorder="little", signed=True) / 4 + ) + res["direct_dups"] = int.from_bytes(data[52:54], byteorder="little") + res["flood_dups"] = int.from_bytes(data[54:56], byteorder="little") + res["rx_airtime"] = int.from_bytes(data[56:60], byteorder="little") data_hex = data[8:].hex() logger.debug(f"Status response: {data_hex}") - + attributes = { "pubkey_prefix": res["pubkey_pre"], } - await self.dispatcher.dispatch(Event(EventType.STATUS_RESPONSE, res, attributes)) - + await self.dispatcher.dispatch( + Event(EventType.STATUS_RESPONSE, res, attributes) + ) + elif packet_type_value == PacketType.LOG_DATA.value: logger.debug(f"Received RF log data: {data.hex()}") - + # Parse as raw RX data - log_data: Dict[str, Any] = { - "raw_hex": data[1:].hex() - } - + log_data: Dict[str, Any] = {"raw_hex": data[1:].hex()} + # First byte is SNR (signed byte, multiplied by 4) if len(data) > 1: snr_byte = data[1] # Convert to signed value snr = (snr_byte if snr_byte < 128 else snr_byte - 256) / 4.0 log_data["snr"] = snr - + # Second byte is RSSI (signed byte) if len(data) > 2: rssi_byte = data[2] # Convert to signed value rssi = rssi_byte if rssi_byte < 128 else rssi_byte - 256 log_data["rssi"] = rssi - + # Remaining bytes are the raw data payload if len(data) > 3: log_data["payload"] = data[3:].hex() log_data["payload_length"] = len(data) - 3 - + attributes = { "pubkey_prefix": log_data["raw_hex"], } - + # Dispatch as RF log data - await self.dispatcher.dispatch(Event(EventType.RX_LOG_DATA, log_data, attributes)) - + await self.dispatcher.dispatch( + Event(EventType.RX_LOG_DATA, log_data, attributes) + ) + elif packet_type_value == PacketType.TRACE_DATA.value: logger.debug(f"Received trace data: {data.hex()}") res = {} - + # According to the source, format is: # 0x89, reserved(0), path_len, flags, tag(4), auth(4), path_hashes[], path_snrs[], final_snr - - reserved = data[1] + path_len = data[2] flags = data[3] - tag = int.from_bytes(data[4:8], byteorder='little') - auth_code = int.from_bytes(data[8:12], byteorder='little') - + tag = int.from_bytes(data[4:8], byteorder="little") + auth_code = int.from_bytes(data[8:12], byteorder="little") + # Initialize result res["tag"] = tag res["auth"] = auth_code res["flags"] = flags res["path_len"] = path_len - + # Process path as array of objects with hash and SNR path_nodes = [] - - if path_len > 0 and len(data) >= 12 + path_len*2 + 1: + + if path_len > 0 and len(data) >= 12 + path_len * 2 + 1: # Extract path with hash and SNR pairs for i in range(path_len): node = { "hash": f"{data[12+i]:02x}", # SNR is stored as a signed byte representing SNR * 4 - "snr": (data[12+path_len+i] if data[12+path_len+i] < 128 else data[12+path_len+i] - 256) / 4.0 + "snr": ( + data[12 + path_len + i] + if data[12 + path_len + i] < 128 + else data[12 + path_len + i] - 256 + ) + / 4.0, } path_nodes.append(node) - + # Add the final node (our device) with its SNR - final_snr_byte = data[12+path_len*2] - final_snr = (final_snr_byte if final_snr_byte < 128 else final_snr_byte - 256) / 4.0 - path_nodes.append({ - "snr": final_snr - }) - + final_snr_byte = data[12 + path_len * 2] + final_snr = ( + final_snr_byte if final_snr_byte < 128 else final_snr_byte - 256 + ) / 4.0 + path_nodes.append({"snr": final_snr}) + res["path"] = path_nodes - + logger.debug(f"Parsed trace data: {res}") - + attributes = { "tag": res["tag"], "auth_code": res["auth"], } - + await self.dispatcher.dispatch(Event(EventType.TRACE_DATA, res, attributes)) elif packet_type_value == PacketType.TELEMETRY_RESPONSE.value: @@ -439,15 +473,19 @@ class MessageReader: lpp_data_list.append(lppdata) i = i + len(lppdata) - lpp = json.loads(json.dumps(LppFrame(lpp_data_list), default=lpp_json_encoder)) + lpp = json.loads( + json.dumps(LppFrame(lpp_data_list), default=lpp_json_encoder) + ) res["lpp"] = lpp attributes = { - "raw" : buf.hex(), + "raw": buf.hex(), } - - await self.dispatcher.dispatch(Event(EventType.TELEMETRY_RESPONSE, res, attributes)) + + await self.dispatcher.dispatch( + Event(EventType.TELEMETRY_RESPONSE, res, attributes) + ) elif packet_type_value == PacketType.BINARY_RESPONSE.value: logger.debug(f"Received binary data: {data.hex()}") @@ -456,11 +494,11 @@ class MessageReader: res["tag"] = data[2:6].hex() res["data"] = data[6:].hex() - attributes = { - "tag" : res["tag"] - } + attributes = {"tag": res["tag"]} - await self.dispatcher.dispatch(Event(EventType.BINARY_RESPONSE, res, attributes)) + await self.dispatcher.dispatch( + Event(EventType.BINARY_RESPONSE, res, attributes) + ) elif packet_type_value == PacketType.PATH_DISCOVERY_RESPONSE.value: logger.debug(f"Received path discovery response: {data.hex()}") @@ -468,18 +506,17 @@ class MessageReader: res["pubkey_pre"] = data[2:8].hex() opl = data[8] res["out_path_len"] = opl - res["out_path"] = data[9:9+opl].hex() - ipl = data[9+opl] + res["out_path"] = data[9 : 9 + opl].hex() + ipl = data[9 + opl] res["in_path_len"] = ipl - res["in_path"] = data[10+opl:10+opl+ipl].hex() + res["in_path"] = data[10 + opl : 10 + opl + ipl].hex() - attributes = { - "pubkey_pre" : res["pubkey_pre"] - } + attributes = {"pubkey_pre": res["pubkey_pre"]} - await self.dispatcher.dispatch(Event(EventType.PATH_RESPONSE, res, attributes)) + await self.dispatcher.dispatch( + Event(EventType.PATH_RESPONSE, res, attributes) + ) else: logger.debug(f"Unhandled data received {data}") logger.debug(f"Unhandled packet type: {packet_type_value}") - diff --git a/src/meshcore/serial_cx.py b/src/meshcore/serial_cx.py index 172baf4..ab48d2b 100644 --- a/src/meshcore/serial_cx.py +++ b/src/meshcore/serial_cx.py @@ -1,6 +1,7 @@ -""" - mccli.py : CLI interface to MeschCore BLE companion app """ +mccli.py : CLI interface to MeschCore BLE companion app +""" + import asyncio import logging import serial_asyncio @@ -8,6 +9,7 @@ import serial_asyncio # Get logger logger = logging.getLogger("meshcore") + class SerialConnection: def __init__(self, port, baudrate, cx_dly=0.2): self.port = port @@ -27,23 +29,28 @@ class SerialConnection: def connection_made(self, transport): self.cx.transport = transport - logger.debug('port opened') - if isinstance(transport, serial_asyncio.SerialTransport) and transport.serial: - transport.serial.rts = False # You can manipulate Serial object via transport - + logger.debug("port opened") + if ( + isinstance(transport, serial_asyncio.SerialTransport) + and transport.serial + ): + transport.serial.rts = ( + False # You can manipulate Serial object via transport + ) + def data_received(self, data): - self.cx.handle_rx(data) - + self.cx.handle_rx(data) + def connection_lost(self, exc): - logger.debug('Serial port closed') + logger.debug("Serial port closed") if self.cx._disconnect_callback: asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) - + def pause_writing(self): - logger.debug('pause writing') - + logger.debug("pause writing") + def resume_writing(self): - logger.debug('resume writing') + logger.debug("resume writing") async def connect(self): """ @@ -51,39 +58,42 @@ class SerialConnection: """ loop = asyncio.get_running_loop() await serial_asyncio.create_serial_connection( - loop, lambda: self.MCSerialClientProtocol(self), - self.port, baudrate=self.baudrate) + loop, + lambda: self.MCSerialClientProtocol(self), + self.port, + baudrate=self.baudrate, + ) - await asyncio.sleep(self.cx_dly) # wait for cx to establish + await asyncio.sleep(self.cx_dly) # wait for cx to establish logger.info("Serial Connection started") return self.port - def set_reader(self, reader) : + def set_reader(self, reader): self.reader = reader def handle_rx(self, data: bytearray): headerlen = len(self.header) framelen = len(self.inframe) - if not self.frame_started : # wait start of frame + if not self.frame_started: # wait start of frame if len(data) >= 3 - headerlen: - self.header = self.header + data[:3-headerlen] + self.header = self.header + data[: 3 - headerlen] self.frame_started = True - self.frame_size = int.from_bytes(self.header[1:], byteorder='little') - self.handle_rx(data[3-headerlen:]) + self.frame_size = int.from_bytes(self.header[1:], byteorder="little") + self.handle_rx(data[3 - headerlen :]) else: self.header = self.header + data else: if framelen + len(data) < self.frame_size: self.inframe = self.inframe + data else: - self.inframe = self.inframe + data[:self.frame_size-framelen] - if not self.reader is None: + self.inframe = self.inframe + data[: self.frame_size - framelen] + if self.reader is not None: asyncio.create_task(self.reader.handle_rx(self.inframe)) self.frame_started = False self.header = b"" self.inframe = b"" if framelen + len(data) > self.frame_size: - self.handle_rx(data[self.frame_size-framelen:]) + self.handle_rx(data[self.frame_size - framelen :]) async def send(self, data): if not self.transport: @@ -93,14 +103,14 @@ class SerialConnection: pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data logger.debug(f"sending pkt : {pkt}") self.transport.write(pkt) - + async def disconnect(self): """Close the serial connection.""" if self.transport: self.transport.close() self.transport = None logger.debug("Serial Connection closed") - + def set_disconnect_callback(self, callback): """Set callback to handle disconnections.""" self._disconnect_callback = callback diff --git a/src/meshcore/tcp_cx.py b/src/meshcore/tcp_cx.py index 8551b9a..81e7cb9 100644 --- a/src/meshcore/tcp_cx.py +++ b/src/meshcore/tcp_cx.py @@ -1,6 +1,7 @@ -""" - mccli.py : CLI interface to MeschCore BLE companion app """ +mccli.py : CLI interface to MeschCore BLE companion app +""" + import asyncio import logging @@ -10,6 +11,7 @@ logger = logging.getLogger("meshcore") # TCP disconnect detection threshold TCP_DISCONNECT_THRESHOLD = 5 + class TCPConnection: def __init__(self, host, port): self.host = host @@ -32,18 +34,18 @@ class TCPConnection: # Reset counters on new connection self.cx._send_count = 0 self.cx._receive_count = 0 - logger.debug('connection established') - + logger.debug("connection established") + def data_received(self, data): - logger.debug('data received') + logger.debug("data received") self.cx._receive_count += 1 self.cx.handle_rx(data) def error_received(self, exc): - logger.error(f'Error received: {exc}') - + logger.error(f"Error received: {exc}") + def connection_lost(self, exc): - logger.debug('TCP server closed the connection') + logger.debug("TCP server closed the connection") if self.cx._disconnect_callback: asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect")) @@ -53,41 +55,41 @@ class TCPConnection: """ loop = asyncio.get_running_loop() await loop.create_connection( - lambda: self.MCClientProtocol(self), - self.host, self.port) + lambda: self.MCClientProtocol(self), self.host, self.port + ) logger.info("TCP Connection started") future = asyncio.Future() future.set_result(self.host) - + return future - def set_reader(self, reader) : + def set_reader(self, reader): self.reader = reader def handle_rx(self, data: bytearray): headerlen = len(self.header) framelen = len(self.inframe) - if not self.frame_started : # wait start of frame + if not self.frame_started: # wait start of frame if len(data) >= 3 - headerlen: - self.header = self.header + data[:3-headerlen] + self.header = self.header + data[: 3 - headerlen] self.frame_started = True - self.frame_size = int.from_bytes(self.header[1:], byteorder='little') - self.handle_rx(data[3-headerlen:]) + self.frame_size = int.from_bytes(self.header[1:], byteorder="little") + self.handle_rx(data[3 - headerlen :]) else: self.header = self.header + data else: if framelen + len(data) < self.frame_size: self.inframe = self.inframe + data else: - self.inframe = self.inframe + data[:self.frame_size-framelen] - if not self.reader is None: + self.inframe = self.inframe + data[: self.frame_size - framelen] + if self.reader is not None: asyncio.create_task(self.reader.handle_rx(self.inframe)) self.frame_started = False self.header = b"" self.inframe = b"" if framelen + len(data) > self.frame_size: - self.handle_rx(data[self.frame_size-framelen:]) + self.handle_rx(data[self.frame_size - framelen :]) async def send(self, data): if not self.transport: @@ -95,28 +97,30 @@ class TCPConnection: if self._disconnect_callback: await self._disconnect_callback("tcp_transport_lost") return - + self._send_count += 1 - + # Check if we've sent packets without any responses if self._send_count - self._receive_count >= TCP_DISCONNECT_THRESHOLD: - logger.debug(f"TCP disconnect detected: sent {self._send_count}, received {self._receive_count}") + logger.debug( + f"TCP disconnect detected: sent {self._send_count}, received {self._receive_count}" + ) if self._disconnect_callback: await self._disconnect_callback("tcp_no_response") return - + size = len(data) pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data logger.debug(f"sending pkt : {pkt}") self.transport.write(pkt) - + async def disconnect(self): """Close the TCP connection.""" if self.transport: self.transport.close() self.transport = None logger.debug("TCP Connection closed") - + def set_disconnect_callback(self, callback): """Set callback to handle disconnections.""" self._disconnect_callback = callback diff --git a/tests/test_ble_connection.py b/tests/test_ble_connection.py index dcd7c56..089b6f1 100644 --- a/tests/test_ble_connection.py +++ b/tests/test_ble_connection.py @@ -2,10 +2,15 @@ import asyncio import unittest from unittest.mock import AsyncMock, MagicMock, patch -from meshcore.ble_cx import BLEConnection, UART_SERVICE_UUID, UART_TX_CHAR_UUID, UART_RX_CHAR_UUID +from meshcore.ble_cx import ( + BLEConnection, + UART_TX_CHAR_UUID, + UART_RX_CHAR_UUID, +) + class TestBLEConnection(unittest.TestCase): - @patch('meshcore.ble_cx.BleakClient') + @patch("meshcore.ble_cx.BleakClient") def test_ble_connection_and_disconnection(self, mock_bleak_client): """ Tests the BLEConnection class for connecting and disconnecting from a BLE device. @@ -13,7 +18,7 @@ class TestBLEConnection(unittest.TestCase): # Arrange mock_client_instance = self._get_mock_bleak_client() mock_bleak_client.return_value = mock_client_instance - + address = "00:11:22:33:44:55" ble_conn = BLEConnection(address=address) @@ -23,10 +28,12 @@ class TestBLEConnection(unittest.TestCase): # Assert mock_client_instance.connect.assert_called_once() - mock_client_instance.start_notify.assert_called_once_with(UART_TX_CHAR_UUID, ble_conn.handle_rx) + mock_client_instance.start_notify.assert_called_once_with( + UART_TX_CHAR_UUID, ble_conn.handle_rx + ) mock_client_instance.disconnect.assert_called_once() - @patch('meshcore.ble_cx.BleakClient') + @patch("meshcore.ble_cx.BleakClient") def test_send_data(self, mock_bleak_client): """ Tests the send method of the BLEConnection class. @@ -34,7 +41,7 @@ class TestBLEConnection(unittest.TestCase): # Arrange mock_client_instance = self._get_mock_bleak_client() mock_bleak_client.return_value = mock_client_instance - + address = "00:11:22:33:44:55" ble_conn = BLEConnection(address=address) asyncio.run(ble_conn.connect()) @@ -44,7 +51,9 @@ class TestBLEConnection(unittest.TestCase): asyncio.run(ble_conn.send(data_to_send)) # Assert - ble_conn.rx_char.write_gatt_char.assert_called_once_with(ble_conn.rx_char, data_to_send, response=False) + ble_conn.rx_char.write_gatt_char.assert_called_once_with( + ble_conn.rx_char, data_to_send, response=False + ) def _get_mock_bleak_client(self): """ @@ -60,12 +69,13 @@ class TestBLEConnection(unittest.TestCase): mock_service = MagicMock() mock_char = MagicMock() mock_char.uuid = UART_RX_CHAR_UUID - mock_char.write_gatt_char = mock_client.write_gatt_char - + mock_char.write_gatt_char = mock_client.write_gatt_char + mock_service.get_characteristic.return_value = mock_char mock_client.services.get_service.return_value = mock_service - + return mock_client -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index 6b96211..91d531b 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -1,11 +1,12 @@ import pytest import asyncio -from unittest.mock import MagicMock, patch, AsyncMock +from unittest.mock import MagicMock, AsyncMock from meshcore.commands import CommandHandler from meshcore.events import EventType, Event pytestmark = pytest.mark.asyncio + # Fixtures @pytest.fixture def mock_connection(): @@ -13,6 +14,7 @@ def mock_connection(): connection.send = AsyncMock() return connection + @pytest.fixture def mock_dispatcher(): dispatcher = MagicMock() @@ -20,29 +22,36 @@ def mock_dispatcher(): dispatcher.dispatch = AsyncMock() return dispatcher + @pytest.fixture def command_handler(mock_connection, mock_dispatcher): handler = CommandHandler() - + async def sender(data): await mock_connection.send(data) + handler._sender_func = sender - + handler.dispatcher = mock_dispatcher return handler + # Test helper def setup_event_response(mock_dispatcher, event_type, payload, attribute_filters=None): async def wait_response(requested_type, filters=None, timeout=None): if requested_type == event_type: if filters and attribute_filters: - if not all(attribute_filters.get(key) == value for key, value in filters.items()): + if not all( + attribute_filters.get(key) == value + for key, value in filters.items() + ): return None return Event(event_type, payload) return None - + mock_dispatcher.wait_for_event.side_effect = wait_response + # Basic tests async def test_send_basic(command_handler, mock_connection): result = await command_handler.send(b"test_data") @@ -50,141 +59,163 @@ async def test_send_basic(command_handler, mock_connection): assert result.type == EventType.OK assert result.payload == {} + async def test_send_with_event(command_handler, mock_connection, mock_dispatcher): expected_payload = {"value": 42} setup_event_response(mock_dispatcher, EventType.OK, expected_payload) - + result = await command_handler.send(b"test_command", [EventType.OK]) - + mock_connection.send.assert_called_once_with(b"test_command") assert result.type == EventType.OK assert result.payload == expected_payload + async def test_send_timeout(command_handler, mock_connection, mock_dispatcher): mock_dispatcher.wait_for_event.side_effect = asyncio.TimeoutError - + result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1) assert result.type == EventType.ERROR assert result.payload == {"reason": "timeout"} + # Destination validation tests async def test_validate_destination_bytes(command_handler, mock_connection): dst = b"123456789012" # 12 bytes await command_handler.send_msg(dst, "test message") - + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") assert b"123456" in mock_connection.send.call_args[0][0] + async def test_validate_destination_hex_string(command_handler, mock_connection): - dst = "0123456789abcdef" + dst = "0123456789abcdef" await command_handler.send_msg(dst, "test message") - + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + async def test_validate_destination_contact_object(command_handler, mock_connection): dst = {"public_key": "0123456789abcdef", "adv_name": "Test Contact"} await command_handler.send_msg(dst, "test message") - + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + # Command tests async def test_send_login(command_handler, mock_connection): await command_handler.send_login("0123456789abcdef", "password") - + assert mock_connection.send.call_args[0][0].startswith(b"\x1a") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"password" in mock_connection.send.call_args[0][0] + async def test_send_msg(command_handler, mock_connection): await command_handler.send_msg("0123456789abcdef", "hello") - + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"hello" in mock_connection.send.call_args[0][0] + async def test_send_cmd(command_handler, mock_connection): await command_handler.send_cmd("0123456789abcdef", "test_cmd") - + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x01\x00") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"test_cmd" in mock_connection.send.call_args[0][0] + # Device settings tests async def test_set_name(command_handler, mock_connection): await command_handler.set_name("Test Device") - + assert mock_connection.send.call_args[0][0].startswith(b"\x08") assert b"Test Device" in mock_connection.send.call_args[0][0] + async def test_set_coords(command_handler, mock_connection): await command_handler.set_coords(37.7749, -122.4194) - + assert mock_connection.send.call_args[0][0].startswith(b"\x0e") # Could add more detailed assertions for the byte encoding + async def test_send_appstart(command_handler, mock_connection): await command_handler.send_appstart() assert mock_connection.send.call_args[0][0].startswith(b"\x01\x03") assert b"mccli" in mock_connection.send.call_args[0][0] + async def test_send_device_query(command_handler, mock_connection): await command_handler.send_device_query() assert mock_connection.send.call_args[0][0].startswith(b"\x16\x03") + async def test_send_advert(command_handler, mock_connection): # Test without flood await command_handler.send_advert(flood=False) assert mock_connection.send.call_args[0][0] == b"\x07" - + # Test with flood mock_connection.reset_mock() await command_handler.send_advert(flood=True) assert mock_connection.send.call_args[0][0] == b"\x07\x01" + async def test_reboot(command_handler, mock_connection): await command_handler.reboot() assert mock_connection.send.call_args[0][0].startswith(b"\x13reboot") + async def test_get_bat(command_handler, mock_connection): await command_handler.get_bat() assert mock_connection.send.call_args[0][0].startswith(b"\x14") + async def test_get_time(command_handler, mock_connection): await command_handler.get_time() assert mock_connection.send.call_args[0][0].startswith(b"\x05") + async def test_set_time(command_handler, mock_connection): timestamp = 1620000000 # Example timestamp await command_handler.set_time(timestamp) assert mock_connection.send.call_args[0][0].startswith(b"\x06") + async def test_set_tx_power(command_handler, mock_connection): await command_handler.set_tx_power(20) assert mock_connection.send.call_args[0][0].startswith(b"\x0c") + async def test_get_contacts(command_handler, mock_connection): await command_handler.get_contacts() assert mock_connection.send.call_args[0][0].startswith(b"\x04") + async def test_reset_path(command_handler, mock_connection): dst = "0123456789abcdef" await command_handler.reset_path(dst) - assert mock_connection.send.call_args[0][0].startswith(b"\x0D") + assert mock_connection.send.call_args[0][0].startswith(b"\x0d") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + async def test_share_contact(command_handler, mock_connection): dst = "0123456789abcdef" await command_handler.share_contact(dst) assert mock_connection.send.call_args[0][0].startswith(b"\x10") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + async def test_export_contact(command_handler, mock_connection): # Test exporting all contacts await command_handler.export_contact() assert mock_connection.send.call_args[0][0] == b"\x11" - + # Test exporting specific contact mock_connection.reset_mock() dst = "0123456789abcdef" @@ -192,20 +223,23 @@ async def test_export_contact(command_handler, mock_connection): assert mock_connection.send.call_args[0][0].startswith(b"\x11") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + async def test_remove_contact(command_handler, mock_connection): dst = "0123456789abcdef" await command_handler.remove_contact(dst) assert mock_connection.send.call_args[0][0].startswith(b"\x0f") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + async def test_get_msg(command_handler, mock_connection): await command_handler.get_msg() - assert mock_connection.send.call_args[0][0].startswith(b"\x0A") - + assert mock_connection.send.call_args[0][0].startswith(b"\x0a") + # Test with custom timeout mock_connection.reset_mock() await command_handler.get_msg(timeout=5.0) - assert mock_connection.send.call_args[0][0].startswith(b"\x0A") + assert mock_connection.send.call_args[0][0].startswith(b"\x0a") + async def test_send_logout(command_handler, mock_connection): dst = "0123456789abcdef" @@ -213,65 +247,74 @@ async def test_send_logout(command_handler, mock_connection): assert mock_connection.send.call_args[0][0].startswith(b"\x1d") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + async def test_send_statusreq(command_handler, mock_connection): dst = "0123456789abcdef" await command_handler.send_statusreq(dst) assert mock_connection.send.call_args[0][0].startswith(b"\x1b") assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + async def test_send_trace(command_handler, mock_connection): # Test with minimal parameters await command_handler.send_trace() first_call = mock_connection.send.call_args[0][0] assert first_call.startswith(b"\x24") # 36 in decimal = 0x24 in hex - + # Test with all parameters mock_connection.reset_mock() await command_handler.send_trace( - auth_code=12345, - tag=67890, - flags=1, - path="01,23,45" + auth_code=12345, tag=67890, flags=1, path="01,23,45" ) second_call = mock_connection.send.call_args[0][0] assert second_call.startswith(b"\x24") -async def test_send_with_multiple_expected_events_returns_first_completed(command_handler, mock_connection, mock_dispatcher): + +async def test_send_with_multiple_expected_events_returns_first_completed( + command_handler, mock_connection, mock_dispatcher +): # Setup the dispatcher to return an ERROR event error_payload = {"reason": "command_failed"} - + async def simulate_error_event(*args, **kwargs): # Simulate an ERROR event being returned return Event(EventType.ERROR, error_payload) - + # Patch the wait_for_event method to return our simulated event mock_dispatcher.wait_for_event.side_effect = simulate_error_event - + # Call send with both OK and ERROR in the expected_events list, with OK first - result = await command_handler.send(b"test_command", [EventType.OK, EventType.ERROR]) - + result = await command_handler.send( + b"test_command", [EventType.OK, EventType.ERROR] + ) + # Verify the command was sent mock_connection.send.assert_called_once_with(b"test_command") - + # Verify that even though OK was listed first, the ERROR event was returned assert result.type == EventType.ERROR assert result.payload == error_payload + # Channel command tests async def test_get_channel(command_handler, mock_connection): await command_handler.get_channel(3) assert mock_connection.send.call_args[0][0] == b"\x1f\x03" + async def test_set_channel(command_handler, mock_connection): channel_secret = bytes(range(16)) # 16 bytes: 0x00, 0x01, ..., 0x0f await command_handler.set_channel(5, "MyChannel", channel_secret) - + expected_data = b"\x20\x05" # CMD_SET_CHANNEL + channel_idx=5 - expected_data += b"MyChannel" + b"\x00" * (32 - len("MyChannel")) # 32-byte padded name + expected_data += b"MyChannel" + b"\x00" * ( + 32 - len("MyChannel") + ) # 32-byte padded name expected_data += channel_secret # 16-byte secret - + assert mock_connection.send.call_args[0][0] == expected_data + async def test_set_channel_invalid_secret_length(command_handler): with pytest.raises(ValueError, match="Channel secret must be exactly 16 bytes"): - await command_handler.set_channel(1, "Test", b"tooshort") \ No newline at end of file + await command_handler.set_channel(1, "Test", b"tooshort") diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py index d3dabb7..0dd272b 100644 --- a/tests/unit/test_events.py +++ b/tests/unit/test_events.py @@ -1,127 +1,129 @@ import pytest import asyncio -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock from meshcore.events import EventDispatcher, EventType, Event pytestmark = pytest.mark.asyncio + @pytest.fixture def dispatcher(): return EventDispatcher() + async def test_subscribe_with_attribute_filter(dispatcher): callback = MagicMock() - + # Subscribe with attribute filters - subscription = dispatcher.subscribe( - EventType.MSG_SENT, + dispatcher.subscribe( + EventType.MSG_SENT, callback, - attribute_filters={"type": 1, "expected_ack": "1234"} + attribute_filters={"type": 1, "expected_ack": "1234"}, ) - + # Start the dispatcher await dispatcher.start() - + try: # Dispatch event that should NOT match (wrong type) - await dispatcher.dispatch(Event( - EventType.MSG_SENT, - {"some": "data"}, - {"type": 2, "expected_ack": "1234"} - )) + await dispatcher.dispatch( + Event( + EventType.MSG_SENT, + {"some": "data"}, + {"type": 2, "expected_ack": "1234"}, + ) + ) await asyncio.sleep(0.1) # Allow processing - + # Callback should NOT have been called assert callback.call_count == 0 - + # Dispatch event that should match all filters - await dispatcher.dispatch(Event( - EventType.MSG_SENT, - {"some": "data"}, - {"type": 1, "expected_ack": "1234"} - )) + await dispatcher.dispatch( + Event( + EventType.MSG_SENT, + {"some": "data"}, + {"type": 1, "expected_ack": "1234"}, + ) + ) await asyncio.sleep(0.1) # Allow processing - + # Callback should have been called once assert callback.call_count == 1 - + finally: await dispatcher.stop() + async def test_wait_for_event_with_attribute_filter(dispatcher): await dispatcher.start() - + try: future_event = asyncio.create_task( dispatcher.wait_for_event( - EventType.ACK, - attribute_filters={"code": "1234"}, - timeout=3.0 + EventType.ACK, attribute_filters={"code": "1234"}, timeout=3.0 ) ) - + await asyncio.sleep(0.1) - - await dispatcher.dispatch(Event( - EventType.ACK, - {"some": "data"}, - {"code": "5678"} - )) - + + await dispatcher.dispatch( + Event(EventType.ACK, {"some": "data"}, {"code": "5678"}) + ) + await asyncio.sleep(0.1) - - await dispatcher.dispatch(Event( - EventType.ACK, - {"ack": "data"}, - {"code": "1234"} - )) - + + await dispatcher.dispatch( + Event(EventType.ACK, {"ack": "data"}, {"code": "1234"}) + ) + result = await asyncio.wait_for(future_event, 3.0) - + assert result is not None assert result.type == EventType.ACK assert result.attributes["code"] == "1234" assert result.payload == {"ack": "data"} - + finally: await dispatcher.stop() + async def test_wait_for_event_timeout_with_filter(dispatcher): await dispatcher.start() - + try: # Wait for an event that won't arrive result = await dispatcher.wait_for_event( - EventType.ACK, - attribute_filters={"code": "1234"}, - timeout=0.1 + EventType.ACK, attribute_filters={"code": "1234"}, timeout=0.1 ) - + # Should get None due to timeout assert result is None - + finally: await dispatcher.stop() + async def test_event_init_with_kwargs(): # Test creating an event with keyword attributes event = Event(EventType.ACK, {"data": "value"}, code="1234", status="ok") - + assert event.type == EventType.ACK assert event.payload == {"data": "value"} assert event.attributes == {"code": "1234", "status": "ok"} + async def test_channel_info_event(): # Test CHANNEL_INFO event type channel_payload = { "channel_idx": 3, "channel_name": "TestChannel", - "channel_secret": b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10" + "channel_secret": b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", } - + event = Event(EventType.CHANNEL_INFO, channel_payload) - + assert event.type == EventType.CHANNEL_INFO assert event.payload["channel_idx"] == 3 assert event.payload["channel_name"] == "TestChannel" - assert len(event.payload["channel_secret"]) == 16 \ No newline at end of file + assert len(event.payload["channel_secret"]) == 16