diff --git a/src/meshcore/ble_cx.py b/src/meshcore/ble_cx.py index 517ef1b..d0aa14d 100644 --- a/src/meshcore/ble_cx.py +++ b/src/meshcore/ble_cx.py @@ -49,7 +49,8 @@ class BLEConnection: if self.client: logger.debug("Using pre-configured BleakClient.") # If a client is already provided, ensure its disconnect callback is set - self.client._disconnected_callback = self.handle_disconnect + assert isinstance(self.client, BleakClient) + self.client.set_disconnected_callback(self.handle_disconnect) self.address = self.client.address else: diff --git a/src/meshcore/events.py b/src/meshcore/events.py index fc81f80..001992d 100644 --- a/src/meshcore/events.py +++ b/src/meshcore/events.py @@ -1,4 +1,5 @@ from enum import Enum +import inspect import logging from math import log from typing import Any, Dict, Optional, Callable, List, Union @@ -133,6 +134,7 @@ class EventDispatcher: while self.running: event = await self.queue.get() logger.debug(f"Dispatching event: {event.type}, {event.payload}, {event.attributes}") + for subscription in self.subscriptions.copy(): # Check if event type matches if subscription.event_type is None or subscription.event_type == event.type: @@ -142,14 +144,23 @@ class EventDispatcher: if not all(event.attributes.get(key) == value for key, value in subscription.attribute_filters.items()): continue - try: - result = subscription.callback(event.clone()) - if asyncio.iscoroutine(result): - await result - except Exception as e: - print(f"Error in event handler: {e}") + + # Fire the call back asychronously + asyncio.create_task(self._execute_callback(subscription.callback, event.clone())) self.queue.task_done() + + async def _execute_callback(self, callback, event): + """Execute a callback with proper error handling""" + try: + if asyncio.iscoroutinefunction(callback): + await callback(event) + else: + result = callback(event) + if inspect.iscoroutine(result): + await result + except Exception as e: + logger.error(f"Error in event handler for {event.type}: {e}", exc_info=True) async def start(self): if not self.running: diff --git a/src/meshcore/serial_cx.py b/src/meshcore/serial_cx.py index 172baf4..45c0034 100644 --- a/src/meshcore/serial_cx.py +++ b/src/meshcore/serial_cx.py @@ -20,6 +20,7 @@ class SerialConnection: self.inframe = b"" self._disconnect_callback = None self.cx_dly = cx_dly + self._connected_event = asyncio.Event() class MCSerialClientProtocol(asyncio.Protocol): def __init__(self, cx): @@ -30,12 +31,14 @@ class SerialConnection: logger.debug('port opened') if isinstance(transport, serial_asyncio.SerialTransport) and transport.serial: transport.serial.rts = False # You can manipulate Serial object via transport + self.cx._connected_event.set() def data_received(self, data): self.cx.handle_rx(data) def connection_lost(self, exc): logger.debug('Serial port closed') + self.cx._connected_event.clear() if self.cx._disconnect_callback: asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) @@ -49,12 +52,14 @@ class SerialConnection: """ Connects to the device """ + self._connected_event.clear() + loop = asyncio.get_running_loop() await serial_asyncio.create_serial_connection( loop, lambda: self.MCSerialClientProtocol(self), self.port, baudrate=self.baudrate) - await asyncio.sleep(self.cx_dly) # wait for cx to establish + await self._connected_event.wait() logger.info("Serial Connection started") return self.port @@ -99,6 +104,7 @@ class SerialConnection: if self.transport: self.transport.close() self.transport = None + self._connected_event.clear() logger.debug("Serial Connection closed") def set_disconnect_callback(self, callback): diff --git a/tests/test_ble_connection.py b/tests/test_ble_connection.py index dc2a649..d15e9f7 100644 --- a/tests/test_ble_connection.py +++ b/tests/test_ble_connection.py @@ -44,6 +44,7 @@ class TestBLEConnection(unittest.TestCase): asyncio.run(ble_conn.send(data_to_send)) # Assert + assert(isinstance(ble_conn.rx_char, MagicMock)) ble_conn.rx_char.write_gatt_char.assert_called_once_with(ble_conn.rx_char, data_to_send, response=True) def _get_mock_bleak_client(self):