This commit is contained in:
mwolter805 2026-04-18 05:15:42 -07:00 committed by GitHub
commit 480cc75d28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 313 additions and 11 deletions

View file

@ -51,6 +51,14 @@ class BLEConnection:
self.pin = pin
self.rx_char = None
self._disconnect_callback = None
self._background_tasks: set[asyncio.Task] = set()
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
async def connect(self):
"""
@ -155,7 +163,7 @@ class BLEConnection:
self.device = self._user_provided_device
if self._disconnect_callback:
asyncio.create_task(self._disconnect_callback("ble_disconnect"))
self._spawn_background(self._disconnect_callback("ble_disconnect"))
def set_disconnect_callback(self, callback):
"""Set callback to handle disconnections."""
@ -166,7 +174,7 @@ class BLEConnection:
def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray):
if self.reader is not None:
asyncio.create_task(self.reader.handle_rx(data))
self._spawn_background(self.reader.handle_rx(data))
async def send(self, data):
if not self.client:

View file

@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes
class CommandHandlerBase:
"""Base class for command handlers.
.. note::
The internal ``asyncio.Lock`` is created lazily on first access
so that it binds to the correct running event loop (required for
Python 3.9/3.10 compatibility).
"""
DEFAULT_TIMEOUT = 5.0
def __init__(self, default_timeout: Optional[float] = None):
self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None
self._reader: Optional[MessageReader] = None
self.dispatcher: Optional[EventDispatcher] = None
self._mesh_request_lock = asyncio.Lock()
self.__mesh_request_lock: Optional[asyncio.Lock] = None
self.default_timeout = (
default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
)
@property
def _mesh_request_lock(self) -> asyncio.Lock:
"""Lazy-init lock so it binds to the running loop, not import-time."""
if self.__mesh_request_lock is None:
self.__mesh_request_lock = asyncio.Lock()
return self.__mesh_request_lock
def set_connection(self, connection: Any) -> None:
async def sender(data: bytes) -> None:
await connection.send(data)
@ -170,7 +185,7 @@ class CommandHandlerBase:
futures: List[asyncio.Future] = []
subscriptions = []
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
for event_type in expected_events:
future = loop.create_future()

View file

@ -129,11 +129,28 @@ class Subscription:
class EventDispatcher:
"""Event dispatch engine.
.. note::
``start()`` must be called before dispatching or processing events.
The internal ``asyncio.Queue`` is created lazily inside ``start()``
so that it binds to the correct running event loop (required for
Python 3.9/3.10 compatibility).
"""
def __init__(self):
self.queue: asyncio.Queue[Event] = asyncio.Queue()
self.queue: Optional[asyncio.Queue[Event]] = None
self.subscriptions: List[Subscription] = []
self.running = False
self._task = None
self._background_tasks: set[asyncio.Task] = set()
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
def subscribe(
self,
@ -166,6 +183,10 @@ class EventDispatcher:
self.subscriptions.remove(subscription)
async def dispatch(self, event: Event):
if self.queue is None:
raise RuntimeError(
"EventDispatcher.start() must be called before dispatching events"
)
await self.queue.put(event)
async def _process_events(self):
@ -197,7 +218,7 @@ class EventDispatcher:
# returns - avoids the race where create_task schedules the callback after
# the waiter has already timed out with done=set().
if asyncio.iscoroutinefunction(subscription.callback):
asyncio.create_task(self._execute_callback(subscription.callback, event.clone()))
self._spawn_background(self._execute_callback(subscription.callback, event.clone()))
else:
try:
subscription.callback(event.clone())
@ -220,6 +241,8 @@ class EventDispatcher:
async def start(self):
if not self.running:
if self.queue is None:
self.queue = asyncio.Queue()
self.running = True
self._task = asyncio.create_task(self._process_events())
@ -227,7 +250,12 @@ class EventDispatcher:
if self.running:
self.running = False
if self._task:
await self.queue.join()
if self.queue is not None:
await self.queue.join()
# Wait for any in-flight async callbacks to complete before
# tearing down (F07: task_done fires before callbacks finish).
if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._task.cancel()
try:
await self._task

View file

@ -20,11 +20,19 @@ class SerialConnection:
self._disconnect_callback = None
self.cx_dly = cx_dly
self._connected_event = asyncio.Event()
self._background_tasks: set[asyncio.Task] = set()
self.frame_expected_size = 0
self.inframe = b""
self.header = b""
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
class MCSerialClientProtocol(asyncio.Protocol):
def __init__(self, cx):
self.cx = cx
@ -44,7 +52,7 @@ class SerialConnection:
self.cx._connected_event.clear()
if self.cx._disconnect_callback:
asyncio.create_task(self.cx._disconnect_callback("serial_disconnect"))
self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect"))
def pause_writing(self):
logger.debug("pause writing")
@ -114,7 +122,7 @@ class SerialConnection:
data = data[upbound:]
if self.reader is not None:
# feed meshcore reader
asyncio.create_task(self.reader.handle_rx(self.inframe))
self._spawn_background(self.reader.handle_rx(self.inframe))
# reset inframe
self.inframe = b""
self.header = b""

View file

@ -24,6 +24,14 @@ class TCPConnection:
self.frame_expected_size = 0
self.header = b""
self.inframe = b""
self._background_tasks: set[asyncio.Task] = set()
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
class MCClientProtocol(asyncio.Protocol):
def __init__(self, cx):
@ -47,7 +55,7 @@ class TCPConnection:
def connection_lost(self, exc):
logger.debug("TCP server closed the connection")
if self.cx._disconnect_callback:
asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect"))
self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect"))
async def connect(self):
"""
@ -108,7 +116,7 @@ class TCPConnection:
data = data[upbound:]
if self.reader is not None:
# feed meshcore reader
asyncio.create_task(self.reader.handle_rx(self.inframe))
self._spawn_background(self.reader.handle_rx(self.inframe))
# reset inframe
self.inframe = b""
self.header = b""

View file

@ -0,0 +1,235 @@
"""
Verification tests for asyncio lifecycle fixes.
"""
import asyncio
import gc
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
from meshcore.events import Event, EventDispatcher, EventType
from meshcore.tcp_cx import TCPConnection
from meshcore.serial_cx import SerialConnection
from meshcore.commands.base import CommandHandlerBase
class TestBackgroundTaskTracking(unittest.TestCase):
"""Fire-and-forget create_task calls must be tracked to prevent GC."""
def test_tcp_spawn_background_retains_task(self):
"""TCP _spawn_background adds the task to _background_tasks."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
completed = asyncio.Event()
async def dummy():
completed.set()
task = cx._spawn_background(dummy())
assert task in cx._background_tasks
await completed.wait()
# After completion, done_callback should have discarded it
await asyncio.sleep(0) # let done callback fire
assert task not in cx._background_tasks
asyncio.run(_run())
def test_serial_spawn_background_retains_task(self):
"""Serial _spawn_background adds the task to _background_tasks."""
async def _run():
with patch("meshcore.serial_cx.asyncio.Event") as mock_event:
mock_event.return_value = MagicMock()
cx = SerialConnection("/dev/null", 115200)
completed = asyncio.Event()
async def dummy():
completed.set()
task = cx._spawn_background(dummy())
assert task in cx._background_tasks
await completed.wait()
await asyncio.sleep(0)
assert task not in cx._background_tasks
asyncio.run(_run())
def test_event_dispatcher_spawn_background_retains_task(self):
"""EventDispatcher _spawn_background adds task to _background_tasks."""
async def _run():
dispatcher = EventDispatcher()
completed = asyncio.Event()
async def dummy():
completed.set()
task = dispatcher._spawn_background(dummy())
assert task in dispatcher._background_tasks
await completed.wait()
await asyncio.sleep(0)
assert task not in dispatcher._background_tasks
asyncio.run(_run())
def test_tcp_handle_rx_uses_tracked_task(self):
"""TCP handle_rx dispatches reader.handle_rx via _spawn_background."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
reader = AsyncMock()
reader.handle_rx = AsyncMock()
cx.set_reader(reader)
# Build a minimal valid frame: 0x3e + 2-byte LE size + payload
payload = b"\x01\x02\x03"
size = len(payload).to_bytes(2, "little")
frame = b"\x3e" + size + payload
cx.handle_rx(frame)
# Task should be tracked
assert len(cx._background_tasks) == 1
# Let task complete
await asyncio.sleep(0.05)
reader.handle_rx.assert_awaited_once_with(payload)
asyncio.run(_run())
def test_tcp_connection_lost_uses_tracked_task(self):
"""TCP connection_lost dispatches disconnect callback via _spawn_background."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
callback = AsyncMock()
cx.set_disconnect_callback(callback)
protocol = cx.MCClientProtocol(cx)
protocol.connection_lost(None)
assert len(cx._background_tasks) == 1
await asyncio.sleep(0.05)
callback.assert_awaited_once_with("tcp_disconnect")
asyncio.run(_run())
def test_gc_does_not_cancel_tracked_tasks(self):
"""Tracked tasks survive GC pressure (the whole point of tracking)."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
result = []
async def slow_task():
await asyncio.sleep(0.05)
result.append("done")
cx._spawn_background(slow_task())
# Force GC — untracked tasks could be collected here
gc.collect()
await asyncio.sleep(0.1)
assert result == ["done"]
asyncio.run(_run())
class TestTaskDoneCorrectness(unittest.TestCase):
"""EventDispatcher.stop() must wait for in-flight async callbacks."""
def test_stop_waits_for_async_callbacks(self):
"""stop() should not return until async callbacks have completed."""
async def _run():
dispatcher = EventDispatcher()
await dispatcher.start()
callback_completed = False
async def slow_callback(event):
nonlocal callback_completed
await asyncio.sleep(0.1)
callback_completed = True
dispatcher.subscribe(EventType.OK, slow_callback)
await dispatcher.dispatch(Event(EventType.OK, {}))
# Give the dispatch loop a moment to pick up the event
await asyncio.sleep(0.02)
# stop() should wait for slow_callback to finish
await dispatcher.stop()
assert callback_completed, "stop() returned before async callback completed"
asyncio.run(_run())
class TestDeferredPrimitiveConstruction(unittest.TestCase):
"""Queue and Lock must not bind to import-time loop."""
def test_event_dispatcher_queue_is_none_before_start(self):
"""EventDispatcher.queue should be None until start() is called."""
dispatcher = EventDispatcher()
assert dispatcher.queue is None
def test_event_dispatcher_queue_created_on_start(self):
"""start() creates the queue."""
async def _run():
dispatcher = EventDispatcher()
assert dispatcher.queue is None
await dispatcher.start()
assert dispatcher.queue is not None
assert isinstance(dispatcher.queue, asyncio.Queue)
await dispatcher.stop()
asyncio.run(_run())
def test_event_dispatcher_dispatch_before_start_raises(self):
"""dispatch() before start() should raise RuntimeError."""
async def _run():
dispatcher = EventDispatcher()
with self.assertRaises(RuntimeError):
await dispatcher.dispatch(Event(EventType.OK, {}))
asyncio.run(_run())
def test_command_handler_lock_is_none_before_use(self):
"""CommandHandlerBase lock should be None until first access."""
handler = CommandHandlerBase()
assert handler._CommandHandlerBase__mesh_request_lock is None
def test_command_handler_lock_created_on_access(self):
"""Accessing _mesh_request_lock creates it lazily."""
async def _run():
handler = CommandHandlerBase()
lock = handler._mesh_request_lock
assert isinstance(lock, asyncio.Lock)
# Second access returns same instance
assert handler._mesh_request_lock is lock
asyncio.run(_run())
class TestGetRunningLoop(unittest.TestCase):
"""get_event_loop() replaced with get_running_loop() in send()."""
def test_send_uses_get_running_loop(self):
"""send() should call get_running_loop, not get_event_loop."""
async def _run():
handler = CommandHandlerBase()
dispatcher = EventDispatcher()
await dispatcher.start()
handler.set_dispatcher(dispatcher)
mock_sender = AsyncMock()
handler._sender_func = mock_sender
# Patch get_running_loop to verify it's called
with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl:
# send with expected_events triggers the loop = asyncio.get_running_loop() path
result = await handler.send(
b"\x01",
expected_events=[EventType.OK],
timeout=0.05,
)
mock_grl.assert_called()
await dispatcher.stop()
asyncio.run(_run())
if __name__ == "__main__":
unittest.main()