mirror of
https://github.com/meshcore-dev/meshcore_py.git
synced 2026-04-20 22:13:49 +00:00
Merge 4fddbffa3d into fbf84cbdac
This commit is contained in:
commit
480cc75d28
6 changed files with 313 additions and 11 deletions
|
|
@ -51,6 +51,14 @@ class BLEConnection:
|
|||
self.pin = pin
|
||||
self.rx_char = None
|
||||
self._disconnect_callback = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
|
|
@ -155,7 +163,7 @@ class BLEConnection:
|
|||
self.device = self._user_provided_device
|
||||
|
||||
if self._disconnect_callback:
|
||||
asyncio.create_task(self._disconnect_callback("ble_disconnect"))
|
||||
self._spawn_background(self._disconnect_callback("ble_disconnect"))
|
||||
|
||||
def set_disconnect_callback(self, callback):
|
||||
"""Set callback to handle disconnections."""
|
||||
|
|
@ -166,7 +174,7 @@ class BLEConnection:
|
|||
|
||||
def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray):
|
||||
if self.reader is not None:
|
||||
asyncio.create_task(self.reader.handle_rx(data))
|
||||
self._spawn_background(self.reader.handle_rx(data))
|
||||
|
||||
async def send(self, data):
|
||||
if not self.client:
|
||||
|
|
|
|||
|
|
@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes
|
|||
|
||||
|
||||
class CommandHandlerBase:
|
||||
"""Base class for command handlers.
|
||||
|
||||
.. note::
|
||||
The internal ``asyncio.Lock`` is created lazily on first access
|
||||
so that it binds to the correct running event loop (required for
|
||||
Python 3.9/3.10 compatibility).
|
||||
"""
|
||||
|
||||
DEFAULT_TIMEOUT = 5.0
|
||||
|
||||
def __init__(self, default_timeout: Optional[float] = None):
|
||||
self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None
|
||||
self._reader: Optional[MessageReader] = None
|
||||
self.dispatcher: Optional[EventDispatcher] = None
|
||||
self._mesh_request_lock = asyncio.Lock()
|
||||
self.__mesh_request_lock: Optional[asyncio.Lock] = None
|
||||
self.default_timeout = (
|
||||
default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
|
||||
)
|
||||
|
||||
@property
|
||||
def _mesh_request_lock(self) -> asyncio.Lock:
|
||||
"""Lazy-init lock so it binds to the running loop, not import-time."""
|
||||
if self.__mesh_request_lock is None:
|
||||
self.__mesh_request_lock = asyncio.Lock()
|
||||
return self.__mesh_request_lock
|
||||
|
||||
def set_connection(self, connection: Any) -> None:
|
||||
async def sender(data: bytes) -> None:
|
||||
await connection.send(data)
|
||||
|
|
@ -170,7 +185,7 @@ class CommandHandlerBase:
|
|||
futures: List[asyncio.Future] = []
|
||||
subscriptions = []
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
for event_type in expected_events:
|
||||
future = loop.create_future()
|
||||
|
||||
|
|
|
|||
|
|
@ -129,11 +129,28 @@ class Subscription:
|
|||
|
||||
|
||||
class EventDispatcher:
|
||||
"""Event dispatch engine.
|
||||
|
||||
.. note::
|
||||
``start()`` must be called before dispatching or processing events.
|
||||
The internal ``asyncio.Queue`` is created lazily inside ``start()``
|
||||
so that it binds to the correct running event loop (required for
|
||||
Python 3.9/3.10 compatibility).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue: asyncio.Queue[Event] = asyncio.Queue()
|
||||
self.queue: Optional[asyncio.Queue[Event]] = None
|
||||
self.subscriptions: List[Subscription] = []
|
||||
self.running = False
|
||||
self._task = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
|
|
@ -166,6 +183,10 @@ class EventDispatcher:
|
|||
self.subscriptions.remove(subscription)
|
||||
|
||||
async def dispatch(self, event: Event):
|
||||
if self.queue is None:
|
||||
raise RuntimeError(
|
||||
"EventDispatcher.start() must be called before dispatching events"
|
||||
)
|
||||
await self.queue.put(event)
|
||||
|
||||
async def _process_events(self):
|
||||
|
|
@ -197,7 +218,7 @@ class EventDispatcher:
|
|||
# returns - avoids the race where create_task schedules the callback after
|
||||
# the waiter has already timed out with done=set().
|
||||
if asyncio.iscoroutinefunction(subscription.callback):
|
||||
asyncio.create_task(self._execute_callback(subscription.callback, event.clone()))
|
||||
self._spawn_background(self._execute_callback(subscription.callback, event.clone()))
|
||||
else:
|
||||
try:
|
||||
subscription.callback(event.clone())
|
||||
|
|
@ -220,6 +241,8 @@ class EventDispatcher:
|
|||
|
||||
async def start(self):
|
||||
if not self.running:
|
||||
if self.queue is None:
|
||||
self.queue = asyncio.Queue()
|
||||
self.running = True
|
||||
self._task = asyncio.create_task(self._process_events())
|
||||
|
||||
|
|
@ -227,7 +250,12 @@ class EventDispatcher:
|
|||
if self.running:
|
||||
self.running = False
|
||||
if self._task:
|
||||
await self.queue.join()
|
||||
if self.queue is not None:
|
||||
await self.queue.join()
|
||||
# Wait for any in-flight async callbacks to complete before
|
||||
# tearing down (F07: task_done fires before callbacks finish).
|
||||
if self._background_tasks:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
|
|
|
|||
|
|
@ -20,11 +20,19 @@ class SerialConnection:
|
|||
self._disconnect_callback = None
|
||||
self.cx_dly = cx_dly
|
||||
self._connected_event = asyncio.Event()
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
self.frame_expected_size = 0
|
||||
self.inframe = b""
|
||||
self.header = b""
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
class MCSerialClientProtocol(asyncio.Protocol):
|
||||
def __init__(self, cx):
|
||||
self.cx = cx
|
||||
|
|
@ -44,7 +52,7 @@ class SerialConnection:
|
|||
self.cx._connected_event.clear()
|
||||
|
||||
if self.cx._disconnect_callback:
|
||||
asyncio.create_task(self.cx._disconnect_callback("serial_disconnect"))
|
||||
self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect"))
|
||||
|
||||
def pause_writing(self):
|
||||
logger.debug("pause writing")
|
||||
|
|
@ -114,7 +122,7 @@ class SerialConnection:
|
|||
data = data[upbound:]
|
||||
if self.reader is not None:
|
||||
# feed meshcore reader
|
||||
asyncio.create_task(self.reader.handle_rx(self.inframe))
|
||||
self._spawn_background(self.reader.handle_rx(self.inframe))
|
||||
# reset inframe
|
||||
self.inframe = b""
|
||||
self.header = b""
|
||||
|
|
|
|||
|
|
@ -24,6 +24,14 @@ class TCPConnection:
|
|||
self.frame_expected_size = 0
|
||||
self.header = b""
|
||||
self.inframe = b""
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
class MCClientProtocol(asyncio.Protocol):
|
||||
def __init__(self, cx):
|
||||
|
|
@ -47,7 +55,7 @@ class TCPConnection:
|
|||
def connection_lost(self, exc):
|
||||
logger.debug("TCP server closed the connection")
|
||||
if self.cx._disconnect_callback:
|
||||
asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect"))
|
||||
self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect"))
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
|
|
@ -108,7 +116,7 @@ class TCPConnection:
|
|||
data = data[upbound:]
|
||||
if self.reader is not None:
|
||||
# feed meshcore reader
|
||||
asyncio.create_task(self.reader.handle_rx(self.inframe))
|
||||
self._spawn_background(self.reader.handle_rx(self.inframe))
|
||||
# reset inframe
|
||||
self.inframe = b""
|
||||
self.header = b""
|
||||
|
|
|
|||
235
tests/unit/test_asyncio_lifecycle.py
Normal file
235
tests/unit/test_asyncio_lifecycle.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
"""
|
||||
Verification tests for asyncio lifecycle fixes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from meshcore.events import Event, EventDispatcher, EventType
|
||||
from meshcore.tcp_cx import TCPConnection
|
||||
from meshcore.serial_cx import SerialConnection
|
||||
from meshcore.commands.base import CommandHandlerBase
|
||||
|
||||
|
||||
class TestBackgroundTaskTracking(unittest.TestCase):
|
||||
"""Fire-and-forget create_task calls must be tracked to prevent GC."""
|
||||
|
||||
def test_tcp_spawn_background_retains_task(self):
|
||||
"""TCP _spawn_background adds the task to _background_tasks."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = cx._spawn_background(dummy())
|
||||
assert task in cx._background_tasks
|
||||
await completed.wait()
|
||||
# After completion, done_callback should have discarded it
|
||||
await asyncio.sleep(0) # let done callback fire
|
||||
assert task not in cx._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_serial_spawn_background_retains_task(self):
|
||||
"""Serial _spawn_background adds the task to _background_tasks."""
|
||||
async def _run():
|
||||
with patch("meshcore.serial_cx.asyncio.Event") as mock_event:
|
||||
mock_event.return_value = MagicMock()
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = cx._spawn_background(dummy())
|
||||
assert task in cx._background_tasks
|
||||
await completed.wait()
|
||||
await asyncio.sleep(0)
|
||||
assert task not in cx._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_event_dispatcher_spawn_background_retains_task(self):
|
||||
"""EventDispatcher _spawn_background adds task to _background_tasks."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = dispatcher._spawn_background(dummy())
|
||||
assert task in dispatcher._background_tasks
|
||||
await completed.wait()
|
||||
await asyncio.sleep(0)
|
||||
assert task not in dispatcher._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_tcp_handle_rx_uses_tracked_task(self):
|
||||
"""TCP handle_rx dispatches reader.handle_rx via _spawn_background."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
reader = AsyncMock()
|
||||
reader.handle_rx = AsyncMock()
|
||||
cx.set_reader(reader)
|
||||
|
||||
# Build a minimal valid frame: 0x3e + 2-byte LE size + payload
|
||||
payload = b"\x01\x02\x03"
|
||||
size = len(payload).to_bytes(2, "little")
|
||||
frame = b"\x3e" + size + payload
|
||||
|
||||
cx.handle_rx(frame)
|
||||
# Task should be tracked
|
||||
assert len(cx._background_tasks) == 1
|
||||
# Let task complete
|
||||
await asyncio.sleep(0.05)
|
||||
reader.handle_rx.assert_awaited_once_with(payload)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_tcp_connection_lost_uses_tracked_task(self):
|
||||
"""TCP connection_lost dispatches disconnect callback via _spawn_background."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
callback = AsyncMock()
|
||||
cx.set_disconnect_callback(callback)
|
||||
|
||||
protocol = cx.MCClientProtocol(cx)
|
||||
protocol.connection_lost(None)
|
||||
|
||||
assert len(cx._background_tasks) == 1
|
||||
await asyncio.sleep(0.05)
|
||||
callback.assert_awaited_once_with("tcp_disconnect")
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_gc_does_not_cancel_tracked_tasks(self):
|
||||
"""Tracked tasks survive GC pressure (the whole point of tracking)."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
result = []
|
||||
|
||||
async def slow_task():
|
||||
await asyncio.sleep(0.05)
|
||||
result.append("done")
|
||||
|
||||
cx._spawn_background(slow_task())
|
||||
# Force GC — untracked tasks could be collected here
|
||||
gc.collect()
|
||||
await asyncio.sleep(0.1)
|
||||
assert result == ["done"]
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestTaskDoneCorrectness(unittest.TestCase):
|
||||
"""EventDispatcher.stop() must wait for in-flight async callbacks."""
|
||||
|
||||
def test_stop_waits_for_async_callbacks(self):
|
||||
"""stop() should not return until async callbacks have completed."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
|
||||
callback_completed = False
|
||||
|
||||
async def slow_callback(event):
|
||||
nonlocal callback_completed
|
||||
await asyncio.sleep(0.1)
|
||||
callback_completed = True
|
||||
|
||||
dispatcher.subscribe(EventType.OK, slow_callback)
|
||||
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||
|
||||
# Give the dispatch loop a moment to pick up the event
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
# stop() should wait for slow_callback to finish
|
||||
await dispatcher.stop()
|
||||
assert callback_completed, "stop() returned before async callback completed"
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestDeferredPrimitiveConstruction(unittest.TestCase):
|
||||
"""Queue and Lock must not bind to import-time loop."""
|
||||
|
||||
def test_event_dispatcher_queue_is_none_before_start(self):
|
||||
"""EventDispatcher.queue should be None until start() is called."""
|
||||
dispatcher = EventDispatcher()
|
||||
assert dispatcher.queue is None
|
||||
|
||||
def test_event_dispatcher_queue_created_on_start(self):
|
||||
"""start() creates the queue."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
assert dispatcher.queue is None
|
||||
await dispatcher.start()
|
||||
assert dispatcher.queue is not None
|
||||
assert isinstance(dispatcher.queue, asyncio.Queue)
|
||||
await dispatcher.stop()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_event_dispatcher_dispatch_before_start_raises(self):
|
||||
"""dispatch() before start() should raise RuntimeError."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
with self.assertRaises(RuntimeError):
|
||||
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_command_handler_lock_is_none_before_use(self):
|
||||
"""CommandHandlerBase lock should be None until first access."""
|
||||
handler = CommandHandlerBase()
|
||||
assert handler._CommandHandlerBase__mesh_request_lock is None
|
||||
|
||||
def test_command_handler_lock_created_on_access(self):
|
||||
"""Accessing _mesh_request_lock creates it lazily."""
|
||||
async def _run():
|
||||
handler = CommandHandlerBase()
|
||||
lock = handler._mesh_request_lock
|
||||
assert isinstance(lock, asyncio.Lock)
|
||||
# Second access returns same instance
|
||||
assert handler._mesh_request_lock is lock
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestGetRunningLoop(unittest.TestCase):
|
||||
"""get_event_loop() replaced with get_running_loop() in send()."""
|
||||
|
||||
def test_send_uses_get_running_loop(self):
|
||||
"""send() should call get_running_loop, not get_event_loop."""
|
||||
async def _run():
|
||||
handler = CommandHandlerBase()
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
handler.set_dispatcher(dispatcher)
|
||||
|
||||
mock_sender = AsyncMock()
|
||||
handler._sender_func = mock_sender
|
||||
|
||||
# Patch get_running_loop to verify it's called
|
||||
with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl:
|
||||
# send with expected_events triggers the loop = asyncio.get_running_loop() path
|
||||
result = await handler.send(
|
||||
b"\x01",
|
||||
expected_events=[EventType.OK],
|
||||
timeout=0.05,
|
||||
)
|
||||
mock_grl.assert_called()
|
||||
|
||||
await dispatcher.stop()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue