diff --git a/src/meshcore/ble_cx.py b/src/meshcore/ble_cx.py index 0ce06d9..86df64d 100644 --- a/src/meshcore/ble_cx.py +++ b/src/meshcore/ble_cx.py @@ -51,6 +51,14 @@ class BLEConnection: self.pin = pin self.rx_char = None self._disconnect_callback = None + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task async def connect(self): """ @@ -155,7 +163,7 @@ class BLEConnection: self.device = self._user_provided_device if self._disconnect_callback: - asyncio.create_task(self._disconnect_callback("ble_disconnect")) + self._spawn_background(self._disconnect_callback("ble_disconnect")) def set_disconnect_callback(self, callback): """Set callback to handle disconnections.""" @@ -166,7 +174,7 @@ class BLEConnection: def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray): if self.reader is not None: - asyncio.create_task(self.reader.handle_rx(data)) + self._spawn_background(self.reader.handle_rx(data)) async def send(self, data): if not self.client: diff --git a/src/meshcore/commands/base.py b/src/meshcore/commands/base.py index 9e0f00e..8983eed 100644 --- a/src/meshcore/commands/base.py +++ b/src/meshcore/commands/base.py @@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes class CommandHandlerBase: + """Base class for command handlers. + + .. note:: + The internal ``asyncio.Lock`` is created lazily on first access + so that it binds to the correct running event loop (required for + Python 3.9/3.10 compatibility). + """ + DEFAULT_TIMEOUT = 5.0 def __init__(self, default_timeout: Optional[float] = None): self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None self._reader: Optional[MessageReader] = None self.dispatcher: Optional[EventDispatcher] = None - self._mesh_request_lock = asyncio.Lock() + self.__mesh_request_lock: Optional[asyncio.Lock] = None self.default_timeout = ( default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT ) + @property + def _mesh_request_lock(self) -> asyncio.Lock: + """Lazy-init lock so it binds to the running loop, not import-time.""" + if self.__mesh_request_lock is None: + self.__mesh_request_lock = asyncio.Lock() + return self.__mesh_request_lock + def set_connection(self, connection: Any) -> None: async def sender(data: bytes) -> None: await connection.send(data) @@ -170,7 +185,7 @@ class CommandHandlerBase: futures: List[asyncio.Future] = [] subscriptions = [] - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() for event_type in expected_events: future = loop.create_future() diff --git a/src/meshcore/events.py b/src/meshcore/events.py index f8a7521..c9c0b6a 100644 --- a/src/meshcore/events.py +++ b/src/meshcore/events.py @@ -129,11 +129,28 @@ class Subscription: class EventDispatcher: + """Event dispatch engine. + + .. note:: + ``start()`` must be called before dispatching or processing events. + The internal ``asyncio.Queue`` is created lazily inside ``start()`` + so that it binds to the correct running event loop (required for + Python 3.9/3.10 compatibility). + """ + def __init__(self): - self.queue: asyncio.Queue[Event] = asyncio.Queue() + self.queue: Optional[asyncio.Queue[Event]] = None self.subscriptions: List[Subscription] = [] self.running = False self._task = None + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task def subscribe( self, @@ -166,6 +183,10 @@ class EventDispatcher: self.subscriptions.remove(subscription) async def dispatch(self, event: Event): + if self.queue is None: + raise RuntimeError( + "EventDispatcher.start() must be called before dispatching events" + ) await self.queue.put(event) async def _process_events(self): @@ -197,7 +218,7 @@ class EventDispatcher: # returns - avoids the race where create_task schedules the callback after # the waiter has already timed out with done=set(). if asyncio.iscoroutinefunction(subscription.callback): - asyncio.create_task(self._execute_callback(subscription.callback, event.clone())) + self._spawn_background(self._execute_callback(subscription.callback, event.clone())) else: try: subscription.callback(event.clone()) @@ -220,6 +241,8 @@ class EventDispatcher: async def start(self): if not self.running: + if self.queue is None: + self.queue = asyncio.Queue() self.running = True self._task = asyncio.create_task(self._process_events()) @@ -227,7 +250,12 @@ class EventDispatcher: if self.running: self.running = False if self._task: - await self.queue.join() + if self.queue is not None: + await self.queue.join() + # Wait for any in-flight async callbacks to complete before + # tearing down (F07: task_done fires before callbacks finish). + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) self._task.cancel() try: await self._task diff --git a/src/meshcore/serial_cx.py b/src/meshcore/serial_cx.py index 61163bd..91d65d6 100644 --- a/src/meshcore/serial_cx.py +++ b/src/meshcore/serial_cx.py @@ -20,11 +20,19 @@ class SerialConnection: self._disconnect_callback = None self.cx_dly = cx_dly self._connected_event = asyncio.Event() + self._background_tasks: set[asyncio.Task] = set() self.frame_expected_size = 0 self.inframe = b"" self.header = b"" + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task + class MCSerialClientProtocol(asyncio.Protocol): def __init__(self, cx): self.cx = cx @@ -44,7 +52,7 @@ class SerialConnection: self.cx._connected_event.clear() if self.cx._disconnect_callback: - asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) + self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect")) def pause_writing(self): logger.debug("pause writing") @@ -114,7 +122,7 @@ class SerialConnection: data = data[upbound:] if self.reader is not None: # feed meshcore reader - asyncio.create_task(self.reader.handle_rx(self.inframe)) + self._spawn_background(self.reader.handle_rx(self.inframe)) # reset inframe self.inframe = b"" self.header = b"" diff --git a/src/meshcore/tcp_cx.py b/src/meshcore/tcp_cx.py index 497c3b2..ad9b9cb 100644 --- a/src/meshcore/tcp_cx.py +++ b/src/meshcore/tcp_cx.py @@ -24,6 +24,14 @@ class TCPConnection: self.frame_expected_size = 0 self.header = b"" self.inframe = b"" + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task class MCClientProtocol(asyncio.Protocol): def __init__(self, cx): @@ -47,7 +55,7 @@ class TCPConnection: def connection_lost(self, exc): logger.debug("TCP server closed the connection") if self.cx._disconnect_callback: - asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect")) + self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect")) async def connect(self): """ @@ -108,7 +116,7 @@ class TCPConnection: data = data[upbound:] if self.reader is not None: # feed meshcore reader - asyncio.create_task(self.reader.handle_rx(self.inframe)) + self._spawn_background(self.reader.handle_rx(self.inframe)) # reset inframe self.inframe = b"" self.header = b"" diff --git a/tests/unit/test_asyncio_lifecycle.py b/tests/unit/test_asyncio_lifecycle.py new file mode 100644 index 0000000..f70f2f0 --- /dev/null +++ b/tests/unit/test_asyncio_lifecycle.py @@ -0,0 +1,235 @@ +""" +Verification tests for asyncio lifecycle fixes. +""" + +import asyncio +import gc +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from meshcore.events import Event, EventDispatcher, EventType +from meshcore.tcp_cx import TCPConnection +from meshcore.serial_cx import SerialConnection +from meshcore.commands.base import CommandHandlerBase + + +class TestBackgroundTaskTracking(unittest.TestCase): + """Fire-and-forget create_task calls must be tracked to prevent GC.""" + + def test_tcp_spawn_background_retains_task(self): + """TCP _spawn_background adds the task to _background_tasks.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = cx._spawn_background(dummy()) + assert task in cx._background_tasks + await completed.wait() + # After completion, done_callback should have discarded it + await asyncio.sleep(0) # let done callback fire + assert task not in cx._background_tasks + + asyncio.run(_run()) + + def test_serial_spawn_background_retains_task(self): + """Serial _spawn_background adds the task to _background_tasks.""" + async def _run(): + with patch("meshcore.serial_cx.asyncio.Event") as mock_event: + mock_event.return_value = MagicMock() + cx = SerialConnection("/dev/null", 115200) + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = cx._spawn_background(dummy()) + assert task in cx._background_tasks + await completed.wait() + await asyncio.sleep(0) + assert task not in cx._background_tasks + + asyncio.run(_run()) + + def test_event_dispatcher_spawn_background_retains_task(self): + """EventDispatcher _spawn_background adds task to _background_tasks.""" + async def _run(): + dispatcher = EventDispatcher() + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = dispatcher._spawn_background(dummy()) + assert task in dispatcher._background_tasks + await completed.wait() + await asyncio.sleep(0) + assert task not in dispatcher._background_tasks + + asyncio.run(_run()) + + def test_tcp_handle_rx_uses_tracked_task(self): + """TCP handle_rx dispatches reader.handle_rx via _spawn_background.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + reader = AsyncMock() + reader.handle_rx = AsyncMock() + cx.set_reader(reader) + + # Build a minimal valid frame: 0x3e + 2-byte LE size + payload + payload = b"\x01\x02\x03" + size = len(payload).to_bytes(2, "little") + frame = b"\x3e" + size + payload + + cx.handle_rx(frame) + # Task should be tracked + assert len(cx._background_tasks) == 1 + # Let task complete + await asyncio.sleep(0.05) + reader.handle_rx.assert_awaited_once_with(payload) + + asyncio.run(_run()) + + def test_tcp_connection_lost_uses_tracked_task(self): + """TCP connection_lost dispatches disconnect callback via _spawn_background.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + callback = AsyncMock() + cx.set_disconnect_callback(callback) + + protocol = cx.MCClientProtocol(cx) + protocol.connection_lost(None) + + assert len(cx._background_tasks) == 1 + await asyncio.sleep(0.05) + callback.assert_awaited_once_with("tcp_disconnect") + + asyncio.run(_run()) + + def test_gc_does_not_cancel_tracked_tasks(self): + """Tracked tasks survive GC pressure (the whole point of tracking).""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + result = [] + + async def slow_task(): + await asyncio.sleep(0.05) + result.append("done") + + cx._spawn_background(slow_task()) + # Force GC — untracked tasks could be collected here + gc.collect() + await asyncio.sleep(0.1) + assert result == ["done"] + + asyncio.run(_run()) + + +class TestTaskDoneCorrectness(unittest.TestCase): + """EventDispatcher.stop() must wait for in-flight async callbacks.""" + + def test_stop_waits_for_async_callbacks(self): + """stop() should not return until async callbacks have completed.""" + async def _run(): + dispatcher = EventDispatcher() + await dispatcher.start() + + callback_completed = False + + async def slow_callback(event): + nonlocal callback_completed + await asyncio.sleep(0.1) + callback_completed = True + + dispatcher.subscribe(EventType.OK, slow_callback) + await dispatcher.dispatch(Event(EventType.OK, {})) + + # Give the dispatch loop a moment to pick up the event + await asyncio.sleep(0.02) + + # stop() should wait for slow_callback to finish + await dispatcher.stop() + assert callback_completed, "stop() returned before async callback completed" + + asyncio.run(_run()) + + +class TestDeferredPrimitiveConstruction(unittest.TestCase): + """Queue and Lock must not bind to import-time loop.""" + + def test_event_dispatcher_queue_is_none_before_start(self): + """EventDispatcher.queue should be None until start() is called.""" + dispatcher = EventDispatcher() + assert dispatcher.queue is None + + def test_event_dispatcher_queue_created_on_start(self): + """start() creates the queue.""" + async def _run(): + dispatcher = EventDispatcher() + assert dispatcher.queue is None + await dispatcher.start() + assert dispatcher.queue is not None + assert isinstance(dispatcher.queue, asyncio.Queue) + await dispatcher.stop() + + asyncio.run(_run()) + + def test_event_dispatcher_dispatch_before_start_raises(self): + """dispatch() before start() should raise RuntimeError.""" + async def _run(): + dispatcher = EventDispatcher() + with self.assertRaises(RuntimeError): + await dispatcher.dispatch(Event(EventType.OK, {})) + + asyncio.run(_run()) + + def test_command_handler_lock_is_none_before_use(self): + """CommandHandlerBase lock should be None until first access.""" + handler = CommandHandlerBase() + assert handler._CommandHandlerBase__mesh_request_lock is None + + def test_command_handler_lock_created_on_access(self): + """Accessing _mesh_request_lock creates it lazily.""" + async def _run(): + handler = CommandHandlerBase() + lock = handler._mesh_request_lock + assert isinstance(lock, asyncio.Lock) + # Second access returns same instance + assert handler._mesh_request_lock is lock + + asyncio.run(_run()) + + +class TestGetRunningLoop(unittest.TestCase): + """get_event_loop() replaced with get_running_loop() in send().""" + + def test_send_uses_get_running_loop(self): + """send() should call get_running_loop, not get_event_loop.""" + async def _run(): + handler = CommandHandlerBase() + dispatcher = EventDispatcher() + await dispatcher.start() + handler.set_dispatcher(dispatcher) + + mock_sender = AsyncMock() + handler._sender_func = mock_sender + + # Patch get_running_loop to verify it's called + with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl: + # send with expected_events triggers the loop = asyncio.get_running_loop() path + result = await handler.send( + b"\x01", + expected_events=[EventType.OK], + timeout=0.05, + ) + mock_grl.assert_called() + + await dispatcher.stop() + + asyncio.run(_run()) + + +if __name__ == "__main__": + unittest.main()