diff --git a/src/meshcore/connection_manager.py b/src/meshcore/connection_manager.py index c95ec37..bcd09ff 100644 --- a/src/meshcore/connection_manager.py +++ b/src/meshcore/connection_manager.py @@ -4,14 +4,23 @@ Connection manager that orchestrates reconnection logic for any connection type. import asyncio import logging -from typing import Optional, Any, Callable, Protocol +from typing import Optional, Any, Awaitable, Callable, Protocol from .events import Event, EventType logger = logging.getLogger("meshcore") class ConnectionProtocol(Protocol): - """Protocol defining the interface that connection classes must implement.""" + """Protocol defining the interface that connection classes must implement. + + Return contract for connect(): + - On success: return a truthy value (typically an address string) + that identifies the connection. This value is included in the + CONNECTED event payload as ``connection_info``. + - On failure: return ``None`` (soft failure — triggers a retry in + ``_attempt_reconnect``) **or** raise an exception (hard failure — + also triggers a retry, logged as an error). + """ async def connect(self) -> Optional[Any]: """Connect and return connection info, or None if failed.""" @@ -39,11 +48,13 @@ class ConnectionManager: event_dispatcher=None, auto_reconnect: bool = False, max_reconnect_attempts: int = 3, + reconnect_callback: Optional[Callable[[], Awaitable[None]]] = None, ): self.connection = connection self.event_dispatcher = event_dispatcher self.auto_reconnect = auto_reconnect self.max_reconnect_attempts = max_reconnect_attempts + self._reconnect_callback = reconnect_callback self._reconnect_attempts = 0 self._is_connected = False @@ -109,45 +120,51 @@ class ConnectionManager: ) async def _attempt_reconnect(self): - """Attempt to reconnect with flat delay.""" - logger.debug( - f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})" - ) - self._reconnect_attempts += 1 + """Attempt to reconnect using an iterative loop. - # Flat 1 second delay for all attempts - await asyncio.sleep(1) + Runs as a single persistent task for the entire reconnect session. + Previous implementation used tail-recursion via create_task which + orphaned the running task reference — disconnect() could only cancel + the newest pointer, leaving earlier attempts in flight (F03). + """ + while self._reconnect_attempts < self.max_reconnect_attempts: + self._reconnect_attempts += 1 + logger.debug( + f"Attempting reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})" + ) + + # Flat 1 second delay for all attempts + await asyncio.sleep(1) + + try: + result = await self.connection.connect() + if result is not None: + self._is_connected = True + self._reconnect_attempts = 0 + + # Invoke reconnect callback (e.g. send_appstart) if provided + if self._reconnect_callback is not None: + try: + await self._reconnect_callback() + except Exception as cb_err: + logger.warning( + f"Reconnect callback failed: {cb_err}" + ) - try: - result = await self.connection.connect() - if result is not None: - self._is_connected = True - self._reconnect_attempts = 0 - await self._emit_event( - EventType.CONNECTED, - {"connection_info": result, "reconnected": True}, - ) - logger.debug("Reconnected successfully") - else: - # Reconnection failed, try again if we haven't exceeded max attempts - if self._reconnect_attempts < self.max_reconnect_attempts: - self._reconnect_task = asyncio.create_task( - self._attempt_reconnect() - ) - else: await self._emit_event( - EventType.DISCONNECTED, - {"reason": "reconnect_failed", "max_attempts_exceeded": True}, + EventType.CONNECTED, + {"connection_info": result, "reconnected": True}, ) - except Exception as e: - logger.debug(f"Reconnection attempt failed: {e}") - if self._reconnect_attempts < self.max_reconnect_attempts: - self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) - else: - await self._emit_event( - EventType.DISCONNECTED, - {"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True}, - ) + logger.debug("Reconnected successfully") + return + except Exception as e: + logger.debug(f"Reconnection attempt failed: {e}") + + # All attempts exhausted + await self._emit_event( + EventType.DISCONNECTED, + {"reason": "reconnect_failed", "max_attempts_exceeded": True}, + ) async def _emit_event(self, event_type: EventType, payload: dict): """Emit connection events if dispatcher is available.""" diff --git a/src/meshcore/meshcore.py b/src/meshcore/meshcore.py index bdd0db3..af1cfa4 100644 --- a/src/meshcore/meshcore.py +++ b/src/meshcore/meshcore.py @@ -28,10 +28,17 @@ class MeshCore: auto_reconnect: bool = False, max_reconnect_attempts: int = 3, ): - # Wrap connection with ConnectionManager + # Wrap connection with ConnectionManager. + # The reconnect callback ensures send_appstart() runs after every + # transport-level reconnect, which is required by firmware to + # initialize the session (F02). self.dispatcher = EventDispatcher() self.connection_manager = ConnectionManager( - cx, self.dispatcher, auto_reconnect, max_reconnect_attempts + cx, + self.dispatcher, + auto_reconnect, + max_reconnect_attempts, + reconnect_callback=self._on_reconnect, ) self.cx = self.connection_manager # For backward compatibility @@ -174,6 +181,15 @@ class MeshCore: return None return mc + async def _on_reconnect(self): + """Callback invoked by ConnectionManager after a successful reconnect. + + Firmware requires CMD_APP_START after every transport-level connection + to initialize the session. MeshCore.connect() does this on the initial + connection; this callback ensures it also happens on reconnects (F02). + """ + await self.commands.send_appstart() + async def connect(self): await self.dispatcher.start() result = await self.connection_manager.connect() diff --git a/src/meshcore/tcp_cx.py b/src/meshcore/tcp_cx.py index 497c3b2..a28b625 100644 --- a/src/meshcore/tcp_cx.py +++ b/src/meshcore/tcp_cx.py @@ -59,10 +59,7 @@ class TCPConnection: ) logger.info("TCP Connection started") - future = asyncio.Future() - future.set_result(self.host) - - return future + return self.host def set_reader(self, reader): self.reader = reader diff --git a/tests/unit/test_connection_manager.py b/tests/unit/test_connection_manager.py new file mode 100644 index 0000000..d92cf06 --- /dev/null +++ b/tests/unit/test_connection_manager.py @@ -0,0 +1,294 @@ +"""Tests for reconnect-path fixes.""" + +import asyncio + +import pytest + +from meshcore.connection_manager import ConnectionManager +from meshcore.events import Event, EventDispatcher, EventType + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class FakeConnection: + """Minimal stub that satisfies ConnectionProtocol.""" + + def __init__(self, connect_results=None): + """ + Args: + connect_results: iterator of return values for successive + connect() calls. ``None`` means soft failure; a string + means success; raising is also supported via sentinel. + """ + self._connect_results = list(connect_results or ["ok"]) + self._call_index = 0 + self.reader = None + + async def connect(self): + if self._call_index < len(self._connect_results): + result = self._connect_results[self._call_index] + self._call_index += 1 + else: + result = self._connect_results[-1] + if isinstance(result, Exception): + raise result + return result + + async def disconnect(self): + pass + + async def send(self, data): + pass + + def set_reader(self, reader): + self.reader = reader + + +class RaisingConnection(FakeConnection): + """Connection that raises on every connect() attempt.""" + + def __init__(self, exc=None): + super().__init__() + self._exc = exc or ConnectionError("boom") + + async def connect(self): + raise self._exc + + +class _EventCollector: + """Subscribes to all events and records them.""" + + def __init__(self, dispatcher: EventDispatcher): + self.events: list[Event] = [] + dispatcher.subscribe(None, self._on_event) + + async def _on_event(self, event: Event): + self.events.append(event) + + +# --------------------------------------------------------------------------- +# TCP connect() should return a plain value, not an asyncio.Future +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_connect_returns_plain_string(): + """TCPConnection.connect() returns self.host (a plain string), not an + asyncio.Future. We test indirectly via ConnectionManager — the + CONNECTED event payload should contain a plain string, not a Future + object.""" + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager(conn, dispatcher) + + result = await mgr.connect() + + assert result == "10.0.0.1" + # Give the dispatcher a moment to deliver the event + await asyncio.sleep(0.05) + connected_events = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected_events) == 1 + payload = connected_events[0].payload + assert payload["connection_info"] == "10.0.0.1" + # The payload value must NOT be an asyncio.Future + assert not isinstance(payload["connection_info"], asyncio.Future) + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# Reconnect attempts must not compound (no tail-recursive create_task) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_reconnect_loop_does_not_compound(): + """_attempt_reconnect must use a single iterative loop. After + max_reconnect_attempts failures, exactly that many connect() calls + should have been made — no exponential fan-out from orphaned tasks.""" + # All attempts fail (return None) + conn = FakeConnection(connect_results=[None, None, None, None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager( + conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3, + ) + mgr._is_connected = True # simulate a live connection + + await mgr.handle_disconnect("test_disconnect") + # Wait for the reconnect loop to exhaust all attempts + # (3 attempts × 1s sleep each, but we can just await the task) + if mgr._reconnect_task: + await mgr._reconnect_task + + # Exactly 3 connect() calls should have been made + assert conn._call_index == 3 + + # A DISCONNECTED event with max_attempts_exceeded should have fired + await asyncio.sleep(0.05) + disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED] + assert len(disconnected) == 1 + assert disconnected[0].payload.get("max_attempts_exceeded") is True + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_disconnect_cancels_reconnect_loop(): + """disconnect() during an active reconnect loop must cancel the + single task cleanly — no orphaned tasks left running.""" + # Simulate a connection that always fails (returns None), giving us + # time to call disconnect() mid-loop. + conn = FakeConnection(connect_results=[None, None, None, None, None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + + # Let the first attempt start (wait just past the 1s sleep) + await asyncio.sleep(1.2) + assert conn._call_index >= 1 # at least one attempt made + + # Now disconnect — should cancel the loop + await mgr.disconnect() + + assert mgr._reconnect_task is None + calls_at_cancel = conn._call_index + + # Wait a bit and confirm no more attempts happened + await asyncio.sleep(2) + assert conn._call_index == calls_at_cancel + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# reconnect_callback (send_appstart) is called after reconnect +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_reconnect_callback_called_after_reconnect(): + """When ConnectionManager reconnects successfully, the + reconnect_callback (e.g. send_appstart) must be invoked.""" + callback_called = [] + + async def fake_appstart(): + callback_called.append(True) + + # First connect() fails (None), second succeeds + conn = FakeConnection(connect_results=[None, "10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + reconnect_callback=fake_appstart, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + assert len(callback_called) == 1 + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_reconnect_callback_failure_does_not_crash_loop(): + """If the reconnect_callback raises, the reconnect still counts as + successful (transport is up) — the callback failure is logged but + does not crash the loop or leave the manager in a broken state.""" + async def failing_callback(): + raise RuntimeError("appstart failed") + + # connect() succeeds on first attempt + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + reconnect_callback=failing_callback, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + # Despite callback failure, CONNECTED event should have fired + await asyncio.sleep(0.05) + connected = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected) == 1 + assert mgr._is_connected is True + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# connect() returning None is a soft failure (BLE scan miss) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_connect_none_is_soft_failure(): + """When connect() returns None (e.g. BLE scan found no device), + ConnectionManager.connect() should NOT set _is_connected and should + NOT emit a CONNECTED event.""" + conn = FakeConnection(connect_results=[None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager(conn, dispatcher) + + result = await mgr.connect() + + assert result is None + assert mgr._is_connected is False + await asyncio.sleep(0.05) + connected = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected) == 0 + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_no_reconnect_callback_is_noop(): + """When no reconnect_callback is provided (backwards compat for + direct ConnectionManager users), reconnect should still work.""" + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + # No reconnect_callback — default None + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + assert mgr._is_connected is True + finally: + await dispatcher.stop()