This commit is contained in:
mwolter805 2026-04-18 05:12:47 -07:00 committed by GitHub
commit def1ea906f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 367 additions and 43 deletions

View file

@ -4,14 +4,23 @@ Connection manager that orchestrates reconnection logic for any connection type.
import asyncio
import logging
from typing import Optional, Any, Callable, Protocol
from typing import Optional, Any, Awaitable, Callable, Protocol
from .events import Event, EventType
logger = logging.getLogger("meshcore")
class ConnectionProtocol(Protocol):
"""Protocol defining the interface that connection classes must implement."""
"""Protocol defining the interface that connection classes must implement.
Return contract for connect():
- On success: return a truthy value (typically an address string)
that identifies the connection. This value is included in the
CONNECTED event payload as ``connection_info``.
- On failure: return ``None`` (soft failure triggers a retry in
``_attempt_reconnect``) **or** raise an exception (hard failure
also triggers a retry, logged as an error).
"""
async def connect(self) -> Optional[Any]:
"""Connect and return connection info, or None if failed."""
@ -39,11 +48,13 @@ class ConnectionManager:
event_dispatcher=None,
auto_reconnect: bool = False,
max_reconnect_attempts: int = 3,
reconnect_callback: Optional[Callable[[], Awaitable[None]]] = None,
):
self.connection = connection
self.event_dispatcher = event_dispatcher
self.auto_reconnect = auto_reconnect
self.max_reconnect_attempts = max_reconnect_attempts
self._reconnect_callback = reconnect_callback
self._reconnect_attempts = 0
self._is_connected = False
@ -109,45 +120,51 @@ class ConnectionManager:
)
async def _attempt_reconnect(self):
"""Attempt to reconnect with flat delay."""
logger.debug(
f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})"
)
self._reconnect_attempts += 1
"""Attempt to reconnect using an iterative loop.
# Flat 1 second delay for all attempts
await asyncio.sleep(1)
Runs as a single persistent task for the entire reconnect session.
Previous implementation used tail-recursion via create_task which
orphaned the running task reference disconnect() could only cancel
the newest pointer, leaving earlier attempts in flight (F03).
"""
while self._reconnect_attempts < self.max_reconnect_attempts:
self._reconnect_attempts += 1
logger.debug(
f"Attempting reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})"
)
# Flat 1 second delay for all attempts
await asyncio.sleep(1)
try:
result = await self.connection.connect()
if result is not None:
self._is_connected = True
self._reconnect_attempts = 0
# Invoke reconnect callback (e.g. send_appstart) if provided
if self._reconnect_callback is not None:
try:
await self._reconnect_callback()
except Exception as cb_err:
logger.warning(
f"Reconnect callback failed: {cb_err}"
)
try:
result = await self.connection.connect()
if result is not None:
self._is_connected = True
self._reconnect_attempts = 0
await self._emit_event(
EventType.CONNECTED,
{"connection_info": result, "reconnected": True},
)
logger.debug("Reconnected successfully")
else:
# Reconnection failed, try again if we haven't exceeded max attempts
if self._reconnect_attempts < self.max_reconnect_attempts:
self._reconnect_task = asyncio.create_task(
self._attempt_reconnect()
)
else:
await self._emit_event(
EventType.DISCONNECTED,
{"reason": "reconnect_failed", "max_attempts_exceeded": True},
EventType.CONNECTED,
{"connection_info": result, "reconnected": True},
)
except Exception as e:
logger.debug(f"Reconnection attempt failed: {e}")
if self._reconnect_attempts < self.max_reconnect_attempts:
self._reconnect_task = asyncio.create_task(self._attempt_reconnect())
else:
await self._emit_event(
EventType.DISCONNECTED,
{"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True},
)
logger.debug("Reconnected successfully")
return
except Exception as e:
logger.debug(f"Reconnection attempt failed: {e}")
# All attempts exhausted
await self._emit_event(
EventType.DISCONNECTED,
{"reason": "reconnect_failed", "max_attempts_exceeded": True},
)
async def _emit_event(self, event_type: EventType, payload: dict):
"""Emit connection events if dispatcher is available."""

View file

@ -28,10 +28,17 @@ class MeshCore:
auto_reconnect: bool = False,
max_reconnect_attempts: int = 3,
):
# Wrap connection with ConnectionManager
# Wrap connection with ConnectionManager.
# The reconnect callback ensures send_appstart() runs after every
# transport-level reconnect, which is required by firmware to
# initialize the session (F02).
self.dispatcher = EventDispatcher()
self.connection_manager = ConnectionManager(
cx, self.dispatcher, auto_reconnect, max_reconnect_attempts
cx,
self.dispatcher,
auto_reconnect,
max_reconnect_attempts,
reconnect_callback=self._on_reconnect,
)
self.cx = self.connection_manager # For backward compatibility
@ -174,6 +181,15 @@ class MeshCore:
return None
return mc
async def _on_reconnect(self):
"""Callback invoked by ConnectionManager after a successful reconnect.
Firmware requires CMD_APP_START after every transport-level connection
to initialize the session. MeshCore.connect() does this on the initial
connection; this callback ensures it also happens on reconnects (F02).
"""
await self.commands.send_appstart()
async def connect(self):
await self.dispatcher.start()
result = await self.connection_manager.connect()

View file

@ -59,10 +59,7 @@ class TCPConnection:
)
logger.info("TCP Connection started")
future = asyncio.Future()
future.set_result(self.host)
return future
return self.host
def set_reader(self, reader):
self.reader = reader

View file

@ -0,0 +1,294 @@
"""Tests for reconnect-path fixes."""
import asyncio
import pytest
from meshcore.connection_manager import ConnectionManager
from meshcore.events import Event, EventDispatcher, EventType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class FakeConnection:
"""Minimal stub that satisfies ConnectionProtocol."""
def __init__(self, connect_results=None):
"""
Args:
connect_results: iterator of return values for successive
connect() calls. ``None`` means soft failure; a string
means success; raising is also supported via sentinel.
"""
self._connect_results = list(connect_results or ["ok"])
self._call_index = 0
self.reader = None
async def connect(self):
if self._call_index < len(self._connect_results):
result = self._connect_results[self._call_index]
self._call_index += 1
else:
result = self._connect_results[-1]
if isinstance(result, Exception):
raise result
return result
async def disconnect(self):
pass
async def send(self, data):
pass
def set_reader(self, reader):
self.reader = reader
class RaisingConnection(FakeConnection):
"""Connection that raises on every connect() attempt."""
def __init__(self, exc=None):
super().__init__()
self._exc = exc or ConnectionError("boom")
async def connect(self):
raise self._exc
class _EventCollector:
"""Subscribes to all events and records them."""
def __init__(self, dispatcher: EventDispatcher):
self.events: list[Event] = []
dispatcher.subscribe(None, self._on_event)
async def _on_event(self, event: Event):
self.events.append(event)
# ---------------------------------------------------------------------------
# TCP connect() should return a plain value, not an asyncio.Future
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tcp_connect_returns_plain_string():
"""TCPConnection.connect() returns self.host (a plain string), not an
asyncio.Future. We test indirectly via ConnectionManager the
CONNECTED event payload should contain a plain string, not a Future
object."""
conn = FakeConnection(connect_results=["10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(conn, dispatcher)
result = await mgr.connect()
assert result == "10.0.0.1"
# Give the dispatcher a moment to deliver the event
await asyncio.sleep(0.05)
connected_events = [e for e in collector.events if e.type == EventType.CONNECTED]
assert len(connected_events) == 1
payload = connected_events[0].payload
assert payload["connection_info"] == "10.0.0.1"
# The payload value must NOT be an asyncio.Future
assert not isinstance(payload["connection_info"], asyncio.Future)
finally:
await dispatcher.stop()
# ---------------------------------------------------------------------------
# Reconnect attempts must not compound (no tail-recursive create_task)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reconnect_loop_does_not_compound():
"""_attempt_reconnect must use a single iterative loop. After
max_reconnect_attempts failures, exactly that many connect() calls
should have been made no exponential fan-out from orphaned tasks."""
# All attempts fail (return None)
conn = FakeConnection(connect_results=[None, None, None, None])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3,
)
mgr._is_connected = True # simulate a live connection
await mgr.handle_disconnect("test_disconnect")
# Wait for the reconnect loop to exhaust all attempts
# (3 attempts × 1s sleep each, but we can just await the task)
if mgr._reconnect_task:
await mgr._reconnect_task
# Exactly 3 connect() calls should have been made
assert conn._call_index == 3
# A DISCONNECTED event with max_attempts_exceeded should have fired
await asyncio.sleep(0.05)
disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED]
assert len(disconnected) == 1
assert disconnected[0].payload.get("max_attempts_exceeded") is True
finally:
await dispatcher.stop()
@pytest.mark.asyncio
async def test_disconnect_cancels_reconnect_loop():
"""disconnect() during an active reconnect loop must cancel the
single task cleanly no orphaned tasks left running."""
# Simulate a connection that always fails (returns None), giving us
# time to call disconnect() mid-loop.
conn = FakeConnection(connect_results=[None, None, None, None, None])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
mgr = ConnectionManager(
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5,
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
# Let the first attempt start (wait just past the 1s sleep)
await asyncio.sleep(1.2)
assert conn._call_index >= 1 # at least one attempt made
# Now disconnect — should cancel the loop
await mgr.disconnect()
assert mgr._reconnect_task is None
calls_at_cancel = conn._call_index
# Wait a bit and confirm no more attempts happened
await asyncio.sleep(2)
assert conn._call_index == calls_at_cancel
finally:
await dispatcher.stop()
# ---------------------------------------------------------------------------
# reconnect_callback (send_appstart) is called after reconnect
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reconnect_callback_called_after_reconnect():
"""When ConnectionManager reconnects successfully, the
reconnect_callback (e.g. send_appstart) must be invoked."""
callback_called = []
async def fake_appstart():
callback_called.append(True)
# First connect() fails (None), second succeeds
conn = FakeConnection(connect_results=[None, "10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
mgr = ConnectionManager(
conn, dispatcher,
auto_reconnect=True,
max_reconnect_attempts=3,
reconnect_callback=fake_appstart,
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
if mgr._reconnect_task:
await mgr._reconnect_task
assert len(callback_called) == 1
finally:
await dispatcher.stop()
@pytest.mark.asyncio
async def test_reconnect_callback_failure_does_not_crash_loop():
"""If the reconnect_callback raises, the reconnect still counts as
successful (transport is up) the callback failure is logged but
does not crash the loop or leave the manager in a broken state."""
async def failing_callback():
raise RuntimeError("appstart failed")
# connect() succeeds on first attempt
conn = FakeConnection(connect_results=["10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(
conn, dispatcher,
auto_reconnect=True,
max_reconnect_attempts=3,
reconnect_callback=failing_callback,
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
if mgr._reconnect_task:
await mgr._reconnect_task
# Despite callback failure, CONNECTED event should have fired
await asyncio.sleep(0.05)
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
assert len(connected) == 1
assert mgr._is_connected is True
finally:
await dispatcher.stop()
# ---------------------------------------------------------------------------
# connect() returning None is a soft failure (BLE scan miss)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_connect_none_is_soft_failure():
"""When connect() returns None (e.g. BLE scan found no device),
ConnectionManager.connect() should NOT set _is_connected and should
NOT emit a CONNECTED event."""
conn = FakeConnection(connect_results=[None])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(conn, dispatcher)
result = await mgr.connect()
assert result is None
assert mgr._is_connected is False
await asyncio.sleep(0.05)
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
assert len(connected) == 0
finally:
await dispatcher.stop()
@pytest.mark.asyncio
async def test_no_reconnect_callback_is_noop():
"""When no reconnect_callback is provided (backwards compat for
direct ConnectionManager users), reconnect should still work."""
conn = FakeConnection(connect_results=["10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
mgr = ConnectionManager(
conn, dispatcher,
auto_reconnect=True,
max_reconnect_attempts=3,
# No reconnect_callback — default None
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
if mgr._reconnect_task:
await mgr._reconnect_task
assert mgr._is_connected is True
finally:
await dispatcher.stop()