From dabc3437dd7b6aac75094543efc40a8add0d74a8 Mon Sep 17 00:00:00 2001 From: Alex Wolden Date: Mon, 30 Jun 2025 15:50:59 -0700 Subject: [PATCH] Add better connection management --- README.md | 37 +++++++ examples/connection_events_example.py | 87 +++++++++++++++ src/meshcore/__init__.py | 1 + src/meshcore/ble_cx.py | 31 ++++-- src/meshcore/connection_manager.py | 149 ++++++++++++++++++++++++++ src/meshcore/events.py | 4 + src/meshcore/meshcore.py | 51 ++++++--- src/meshcore/serial_cx.py | 11 +- src/meshcore/tcp_cx.py | 30 +++++- 9 files changed, 370 insertions(+), 31 deletions(-) create mode 100644 examples/connection_events_example.py create mode 100644 src/meshcore/connection_manager.py diff --git a/README.md b/README.md index 1d8cf84..6c1aff3 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,43 @@ meshcore = await MeshCore.create_ble("12:34:56:78:90:AB") meshcore = await MeshCore.create_tcp("192.168.1.100", 4000) ``` +#### Auto-Reconnect and Connection Events + +Enable automatic reconnection when connections are lost: + +```python +# Enable auto-reconnect with custom retry limits +meshcore = await MeshCore.create_tcp( + "192.168.1.100", 4000, + auto_reconnect=True, + max_reconnect_attempts=5 +) + +# Subscribe to connection events +async def on_connected(event): + print(f"Connected: {event.payload}") + if event.payload.get('reconnected'): + print("Successfully reconnected!") + +async def on_disconnected(event): + print(f"Disconnected: {event.payload['reason']}") + if event.payload.get('max_attempts_exceeded'): + print("Max reconnection attempts exceeded") + +meshcore.subscribe(EventType.CONNECTED, on_connected) +meshcore.subscribe(EventType.DISCONNECTED, on_disconnected) + +# Check connection status +if meshcore.is_connected: + print("Device is currently connected") +``` + +**Auto-reconnect features:** +- Exponential backoff (1s, 2s, 4s, 8s max delay) +- Configurable retry limits (default: 3 attempts) +- Automatic disconnect detection (especially useful for TCP connections) +- Connection events with detailed information + ### Using Commands (Synchronous Style) Send commands and wait for responses: diff --git a/examples/connection_events_example.py b/examples/connection_events_example.py new file mode 100644 index 0000000..137adf4 --- /dev/null +++ b/examples/connection_events_example.py @@ -0,0 +1,87 @@ +""" +Example demonstrating connection events and auto-reconnect functionality. +""" +import asyncio +import logging +import sys +from meshcore import MeshCore +from meshcore.events import EventType + +logging.basicConfig(level=logging.DEBUG) + +async def main(): + mc = None + # Example with auto-reconnect enabled + try: + # mc = await MeshCore.create_serial( + # port="/dev/cu.usbmodem1101", + # auto_reconnect=True, + # max_reconnect_attempts=3, + # debug=True + # ) + + # mc = await MeshCore.create_tcp( + # host="192.168.1.22", + # port=5000, + # auto_reconnect=True, + # max_reconnect_attempts=sys.maxsize, + # debug=True + # ) + + mc = await MeshCore.create_ble( + address="92849669", + auto_reconnect=True, + max_reconnect_attempts=3, + debug=True + ) + + # Subscribe to connection events + async def on_connected(event): + print(f"āœ… Connected! Info: {event.payload}") + if event.payload.get('reconnected'): + print("šŸ”„ This was a reconnection!") + + async def on_disconnected(event): + print(f"āŒ Disconnected! Reason: {event.payload.get('reason')}") + if event.payload.get('max_attempts_exceeded'): + print("āš ļø Max reconnection attempts exceeded") + + mc.subscribe(EventType.CONNECTED, on_connected) + mc.subscribe(EventType.DISCONNECTED, on_disconnected) + + # Check connection status + + print("\nšŸ“± Disconnect your device now to test auto-reconnect...") + print("Press Ctrl+C to exit") + + # Keep running and periodically test the connection + while True: + await asyncio.sleep(2) + print(f"Connected: {mc.is_connected}") + if mc.is_connected: + try: + print("šŸ”„ Testing connection by getting battery...") + result = await mc.commands.get_bat() + + if result.type == EventType.ERROR: + print(f"āŒ Error getting battery: {result.payload}") + else: + print("āœ… Connection test successfeul") + except Exception as e: + print(f"āŒ Connection test failed: {e}") + # This should trigger the disconnect detection + else: + print("ā³ Waiting for reconnection...") + + except KeyboardInterrupt: + print("\nšŸ›‘ Exiting...") + except ConnectionError as e: + print(f"āŒ Failed to connect: {e}") + finally: + if mc is not None: + await mc.disconnect() + print(f"Connected after disconnect: {mc.is_connected}") + print(f"Connected after disconnect: {mc.is_connected}") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/meshcore/__init__.py b/src/meshcore/__init__.py index 0171409..f64d872 100644 --- a/src/meshcore/__init__.py +++ b/src/meshcore/__init__.py @@ -5,6 +5,7 @@ logging.basicConfig(level=logging.INFO) from meshcore.events import EventType from meshcore.meshcore import MeshCore, logger +from meshcore.connection_manager import ConnectionManager from meshcore.tcp_cx import TCPConnection from meshcore.ble_cx import BLEConnection from meshcore.serial_cx import SerialConnection diff --git a/src/meshcore/ble_cx.py b/src/meshcore/ble_cx.py index 72e363f..e68d7b4 100644 --- a/src/meshcore/ble_cx.py +++ b/src/meshcore/ble_cx.py @@ -21,9 +21,10 @@ class BLEConnection: def __init__(self, address): """ Constructor : specify address """ self.address = address + self._user_provided_address = address self.client = None self.rx_char = None - self.mc = None + self._disconnect_callback = None async def connect(self): """ @@ -31,6 +32,7 @@ class BLEConnection: Returns : the address used for connection """ + logger.debug(f"Connecting existing connection: {self.client} with address {self.address}") def match_meshcore_device(_: BLEDevice, adv: AdvertisementData): """ Filter to mach MeshCore devices """ if not adv.local_name is None\ @@ -39,20 +41,20 @@ class BLEConnection: return True return False - if self.address is None or self.address == "" or len(self.address.split(":")) != 6 : + if self.address is None or self.address == "" or len(self.address.split(":")) != 6: scanner = BleakScanner() logger.info("Scanning for devices") device = await scanner.find_device_by_filter(match_meshcore_device) - if device is None : + if device is None: return None logger.info(f"Found device : {device}") - self.client = BleakClient(device) + self.client = BleakClient(device, disconnected_callback=self.handle_disconnect) self.address = self.client.address else: - self.client = BleakClient(self.address) + self.client = BleakClient(self.address, disconnected_callback=self.handle_disconnect) try: - await self.client.connect(disconnected_callback=self.handle_disconnect) + await self.client.connect() except BleakDeviceNotFoundError: return None except TimeoutError: @@ -69,12 +71,19 @@ class BLEConnection: logger.info("BLE Connection started") return self.address - def handle_disconnect(self, _: BleakClient): + def handle_disconnect(self, client: BleakClient): """ Callback to handle disconnection """ - logger.info("Device was disconnected, goodbye.") - # cancelling all tasks effectively ends the program - for task in asyncio.all_tasks(): - task.cancel() + logger.debug(f"BLE device disconnected: {client.address} (is_connected: {client.is_connected})") + # Reset the address we found to what user specified + # this allows to reconnect to the same device + self.address = self._user_provided_address + + if self._disconnect_callback: + asyncio.create_task(self._disconnect_callback("ble_disconnect")) + + def set_disconnect_callback(self, callback): + """Set callback to handle disconnections.""" + self._disconnect_callback = callback def set_reader(self, reader) : self.reader = reader diff --git a/src/meshcore/connection_manager.py b/src/meshcore/connection_manager.py new file mode 100644 index 0000000..fd5984f --- /dev/null +++ b/src/meshcore/connection_manager.py @@ -0,0 +1,149 @@ +""" +Connection manager that orchestrates reconnection logic for any connection type. +""" +import asyncio +import logging +from typing import Optional, Any, Callable, Protocol +from .events import Event, EventType + +logger = logging.getLogger("meshcore") + +class ConnectionProtocol(Protocol): + """Protocol defining the interface that connection classes must implement.""" + + async def connect(self) -> Optional[Any]: + """Connect and return connection info, or None if failed.""" + ... + + async def disconnect(self): + """Disconnect from the device/server.""" + ... + + async def send(self, data): + """Send data through the connection.""" + ... + + def set_reader(self, reader): + """Set the message reader.""" + ... + + +class ConnectionManager: + """Manages connection lifecycle with auto-reconnect and event emission.""" + + def __init__(self, connection: ConnectionProtocol, event_dispatcher=None, + auto_reconnect: bool = False, max_reconnect_attempts: int = 3): + self.connection = connection + self.event_dispatcher = event_dispatcher + self.auto_reconnect = auto_reconnect + self.max_reconnect_attempts = max_reconnect_attempts + + self._reconnect_attempts = 0 + self._is_connected = False + self._reconnect_task = None + self._disconnect_callback: Optional[Callable] = None + + def set_disconnect_callback(self, callback: Callable): + """Set a callback to be called when disconnection is detected.""" + self._disconnect_callback = callback + + async def connect(self) -> Optional[Any]: + """Connect with event handling and state management.""" + result = await self.connection.connect() + + if result is not None: + self._is_connected = True + self._reconnect_attempts = 0 + await self._emit_event(EventType.CONNECTED, {"connection_info": result}) + logger.debug(f"Connected successfully: {result}") + else: + logger.debug("Connection failed") + + return result + + async def disconnect(self): + """Disconnect with proper cleanup.""" + if self._reconnect_task: + self._reconnect_task.cancel() + try: + await self._reconnect_task + except asyncio.CancelledError: + pass + self._reconnect_task = None + + if self._is_connected: + await self.connection.disconnect() + self._is_connected = False + await self._emit_event(EventType.DISCONNECTED, {"reason": "manual_disconnect"}) + + async def handle_disconnect(self, reason: str = "unknown"): + """Handle unexpected disconnections with optional auto-reconnect.""" + if not self._is_connected: + return + + self._is_connected = False + logger.debug(f"Connection lost: {reason}") + + if self.auto_reconnect and self._reconnect_attempts < self.max_reconnect_attempts: + self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) + else: + await self._emit_event(EventType.DISCONNECTED, { + "reason": reason, + "reconnect_failed": self._reconnect_attempts >= self.max_reconnect_attempts + }) + + async def _attempt_reconnect(self): + """Attempt to reconnect with flat delay.""" + logger.debug(f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})") + self._reconnect_attempts += 1 + + # Flat 1 second delay for all attempts + await asyncio.sleep(1) + + try: + result = await self.connection.connect() + if result is not None: + self._is_connected = True + self._reconnect_attempts = 0 + await self._emit_event(EventType.CONNECTED, { + "connection_info": result, + "reconnected": True + }) + logger.debug(f"Reconnected successfully") + else: + # Reconnection failed, try again if we haven't exceeded max attempts + if self._reconnect_attempts < self.max_reconnect_attempts: + self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) + else: + await self._emit_event(EventType.DISCONNECTED, { + "reason": "reconnect_failed", + "max_attempts_exceeded": True + }) + except Exception as e: + logger.debug(f"Reconnection attempt failed: {e}") + if self._reconnect_attempts < self.max_reconnect_attempts: + self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) + else: + await self._emit_event(EventType.DISCONNECTED, { + "reason": f"reconnect_error: {e}", + "max_attempts_exceeded": True + }) + + async def _emit_event(self, event_type: EventType, payload: dict): + """Emit connection events if dispatcher is available.""" + if self.event_dispatcher: + event = Event(event_type, payload) + await self.event_dispatcher.dispatch(event) + + @property + def is_connected(self) -> bool: + """Check if the connection is active.""" + return self._is_connected + + async def send(self, data): + """Send data through the managed connection.""" + return await self.connection.send(data) + + def set_reader(self, reader): + """Set the message reader on the underlying connection.""" + self.connection.set_reader(reader) \ No newline at end of file diff --git a/src/meshcore/events.py b/src/meshcore/events.py index c498dbc..bedd36d 100644 --- a/src/meshcore/events.py +++ b/src/meshcore/events.py @@ -40,6 +40,10 @@ class EventType(Enum): # Command response types OK = "command_ok" ERROR = "command_error" + + # Connection events + CONNECTED = "connected" + DISCONNECTED = "disconnected" @dataclass diff --git a/src/meshcore/meshcore.py b/src/meshcore/meshcore.py index ad55b34..555f7f5 100644 --- a/src/meshcore/meshcore.py +++ b/src/meshcore/meshcore.py @@ -5,6 +5,7 @@ from typing import Optional, Dict, Any, Union from .events import EventDispatcher, EventType from .reader import MessageReader from .commands import CommandHandler +from .connection_manager import ConnectionManager from .ble_cx import BLEConnection from .tcp_cx import TCPConnection from .serial_cx import SerialConnection @@ -16,9 +17,14 @@ class MeshCore: """ Interface to a MeshCore device """ - def __init__(self, cx, debug=False, default_timeout=None): - self.cx = cx + def __init__(self, cx, debug=False, default_timeout=None, auto_reconnect=False, max_reconnect_attempts=3): + # Wrap connection with ConnectionManager self.dispatcher = EventDispatcher() + self.connection_manager = ConnectionManager( + cx, self.dispatcher, auto_reconnect, max_reconnect_attempts + ) + self.cx = self.connection_manager # For backward compatibility + self._reader = MessageReader(self.dispatcher) self.commands = CommandHandler(default_timeout=default_timeout) @@ -29,7 +35,7 @@ class MeshCore: logger.setLevel(logging.INFO) # Set up connections - self.commands.set_connection(cx) + self.commands.set_connection(self.connection_manager) # Set the dispatcher in the command handler self.commands.set_dispatcher(self.dispatcher) @@ -43,47 +49,54 @@ class MeshCore: # Set up event subscriptions to track data self._setup_data_tracking() - cx.set_reader(self._reader) + self.connection_manager.set_reader(self._reader) + + # Set up disconnect callback + cx.set_disconnect_callback(self.connection_manager.handle_disconnect) @classmethod - async def create_tcp(cls, host: str, port: int, debug: bool = False, default_timeout=None) -> 'MeshCore': + async def create_tcp(cls, host: str, port: int, debug: bool = False, default_timeout=None, + auto_reconnect: bool = False, max_reconnect_attempts: int = 3) -> 'MeshCore': """Create and connect a MeshCore instance using TCP connection""" connection = TCPConnection(host, port) - await connection.connect() - mc = cls(connection, debug=debug, default_timeout=default_timeout) + mc = cls(connection, debug=debug, default_timeout=default_timeout, + auto_reconnect=auto_reconnect, max_reconnect_attempts=max_reconnect_attempts) await mc.connect() return mc @classmethod - async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False, default_timeout=None) -> 'MeshCore': + async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False, default_timeout=None, + auto_reconnect: bool = False, max_reconnect_attempts: int = 3) -> 'MeshCore': """Create and connect a MeshCore instance using serial connection""" connection = SerialConnection(port, baudrate) - await connection.connect() await asyncio.sleep(0.1) # Time for transport to establish - mc = cls(connection, debug=debug, default_timeout=default_timeout) + mc = cls(connection, debug=debug, default_timeout=default_timeout, + auto_reconnect=auto_reconnect, max_reconnect_attempts=max_reconnect_attempts) await mc.connect() return mc @classmethod - async def create_ble(cls, address: Optional[str] = None, debug: bool = False, default_timeout=None) -> 'MeshCore': + async def create_ble(cls, address: Optional[str] = None, debug: bool = False, default_timeout=None, + auto_reconnect: bool = False, max_reconnect_attempts: int = 3) -> 'MeshCore': """Create and connect a MeshCore instance using BLE connection If address is None, it will scan for and connect to the first available MeshCore device. """ connection = BLEConnection(address) - result = await connection.connect() - if result is None: - raise ConnectionError("Failed to connect to BLE device") - mc = cls(connection, debug=debug, default_timeout=default_timeout) + mc = cls(connection, debug=debug, default_timeout=default_timeout, + auto_reconnect=auto_reconnect, max_reconnect_attempts=max_reconnect_attempts) await mc.connect() return mc async def connect(self): await self.dispatcher.start() + result = await self.connection_manager.connect() + if result is None: + raise ConnectionError("Failed to connect to device") return await self.commands.send_appstart() async def disconnect(self): @@ -96,8 +109,7 @@ class MeshCore: await self.stop_auto_message_fetching() # Disconnect the connection object - if self.cx: - await self.cx.disconnect() + await self.connection_manager.disconnect() def stop(self): """Synchronously stop the event dispatcher task""" @@ -195,6 +207,11 @@ class MeshCore: """Get the current device time""" return self._time + @property + def is_connected(self): + """Check if the connection is active""" + return self.connection_manager.is_connected + @property def default_timeout(self): """Get the default timeout for commands""" diff --git a/src/meshcore/serial_cx.py b/src/meshcore/serial_cx.py index 61931b5..05419bd 100644 --- a/src/meshcore/serial_cx.py +++ b/src/meshcore/serial_cx.py @@ -17,6 +17,7 @@ class SerialConnection: self.transport = None self.header = b"" self.inframe = b"" + self._disconnect_callback = None class MCSerialClientProtocol(asyncio.Protocol): def __init__(self, cx): @@ -32,7 +33,9 @@ class SerialConnection: self.cx.handle_rx(data) def connection_lost(self, exc): - logger.info('port closed') + logger.debug('Serial port closed') + if self.cx._disconnect_callback: + asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) def pause_writing(self): logger.debug('pause writing') @@ -93,4 +96,8 @@ class SerialConnection: if self.transport: self.transport.close() self.transport = None - logger.debug("Serial Connection closed") \ No newline at end of file + logger.debug("Serial Connection closed") + + def set_disconnect_callback(self, callback): + """Set callback to handle disconnections.""" + self._disconnect_callback = callback \ No newline at end of file diff --git a/src/meshcore/tcp_cx.py b/src/meshcore/tcp_cx.py index 8d0aff9..8551b9a 100644 --- a/src/meshcore/tcp_cx.py +++ b/src/meshcore/tcp_cx.py @@ -7,6 +7,9 @@ import logging # Get logger logger = logging.getLogger("meshcore") +# TCP disconnect detection threshold +TCP_DISCONNECT_THRESHOLD = 5 + class TCPConnection: def __init__(self, host, port): self.host = host @@ -16,6 +19,9 @@ class TCPConnection: self.frame_size = 0 self.header = b"" self.inframe = b"" + self._disconnect_callback = None + self._send_count = 0 + self._receive_count = 0 class MCClientProtocol(asyncio.Protocol): def __init__(self, cx): @@ -23,17 +29,23 @@ class TCPConnection: def connection_made(self, transport): self.cx.transport = transport + # Reset counters on new connection + self.cx._send_count = 0 + self.cx._receive_count = 0 logger.debug('connection established') def data_received(self, data): logger.debug('data received') + self.cx._receive_count += 1 self.cx.handle_rx(data) def error_received(self, exc): logger.error(f'Error received: {exc}') def connection_lost(self, exc): - logger.info('The server closed the connection') + logger.debug('TCP server closed the connection') + if self.cx._disconnect_callback: + asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect")) async def connect(self): """ @@ -80,7 +92,19 @@ class TCPConnection: async def send(self, data): if not self.transport: logger.error("Transport not connected, cannot send data") + if self._disconnect_callback: + await self._disconnect_callback("tcp_transport_lost") return + + self._send_count += 1 + + # Check if we've sent packets without any responses + if self._send_count - self._receive_count >= TCP_DISCONNECT_THRESHOLD: + logger.debug(f"TCP disconnect detected: sent {self._send_count}, received {self._receive_count}") + if self._disconnect_callback: + await self._disconnect_callback("tcp_no_response") + return + size = len(data) pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data logger.debug(f"sending pkt : {pkt}") @@ -92,3 +116,7 @@ class TCPConnection: self.transport.close() self.transport = None logger.debug("TCP Connection closed") + + def set_disconnect_callback(self, callback): + """Set callback to handle disconnections.""" + self._disconnect_callback = callback