diff --git a/src/meshcore/commands/base.py b/src/meshcore/commands/base.py index 9e0f00e..a927882 100644 --- a/src/meshcore/commands/base.py +++ b/src/meshcore/commands/base.py @@ -90,6 +90,14 @@ class CommandHandlerBase: expected_events: Optional[Union[EventType, List[EventType]]] = None, timeout: Optional[float] = None, ) -> Event: + """Wait for the first of *expected_events* to arrive. + + Returns the first matched ``Event``. When ``EventType.ERROR`` is + among the expected types, the caller **must** check + ``result.is_error()`` before accessing command-specific payload + keys — an ERROR payload is ``{"reason": "..."}`` and will + ``KeyError`` on any other key. + """ try: # Convert single event to list if needed if not isinstance(expected_events, list): @@ -129,9 +137,6 @@ class CommandHandlerBase: logger.debug(f"Command error: {e}") return Event(EventType.ERROR, {"error": str(e)}) - return Event(EventType.ERROR, {}) - - async def send( self, data: bytes, @@ -151,7 +156,14 @@ class CommandHandlerBase: timeout: Timeout in seconds, or None to use default_timeout Returns: - Event: The full event object that was received in response to the command + Event: The full event object that was received in response to + the command. + + Important: + When ``EventType.ERROR`` is included in *expected_events*, the + returned event may be an error response. Callers **must** + check ``result.is_error()`` before accessing command-specific + payload keys to avoid ``KeyError``. """ if not self.dispatcher: raise RuntimeError("Dispatcher not set, cannot send commands") @@ -266,6 +278,7 @@ class CommandHandlerBase: contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path if contact is None: logger.error("No contact found") + return Event(EventType.ERROR, {"reason": "contact_not_found"}) zero_hop = False if contact["out_path_len"] == -1: diff --git a/src/meshcore/commands/device.py b/src/meshcore/commands/device.py index af986db..e5d2b5d 100644 --- a/src/meshcore/commands/device.py +++ b/src/meshcore/commands/device.py @@ -13,7 +13,7 @@ class DeviceCommands(CommandHandlerBase): async def send_appstart(self) -> Event: logger.debug("Sending appstart command") b1 = bytearray(b"\x01\x03 mccli") - return await self.send(b1, [EventType.SELF_INFO]) + return await self.send(b1, [EventType.SELF_INFO, EventType.ERROR]) async def send_device_query(self) -> Event: logger.debug("Sending device query command") @@ -129,32 +129,50 @@ class DeviceCommands(CommandHandlerBase): return await self.send(data, [EventType.OK, EventType.ERROR]) async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_base"] = telemetry_mode_base return await self.set_other_params_from_infos(infos) async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_loc"] = telemetry_mode_loc return await self.set_other_params_from_infos(infos) async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_env"] = telemetry_mode_env return await self.set_other_params_from_infos(infos) async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["manual_add_contacts"] = manual_add_contacts return await self.set_other_params_from_infos(infos) async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["adv_loc_policy"] = advert_loc_policy return await self.set_other_params_from_infos(infos) async def set_multi_acks(self, multi_acks: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["multi_acks"] = multi_acks return await self.set_other_params_from_infos(infos) diff --git a/src/meshcore/commands/messaging.py b/src/meshcore/commands/messaging.py index b266ae0..b821ea0 100644 --- a/src/meshcore/commands/messaging.py +++ b/src/meshcore/commands/messaging.py @@ -144,8 +144,12 @@ class MessagingCommands(CommandHandlerBase): logger.info(f"Retry sending msg: {attempts + 1}") result = await self.send_msg(dst, msg, timestamp, attempt=attempts) - if result.type == EventType.ERROR: - logger.error(f"⚠️ Failed to send message: {result.payload}") + if result.is_error(): + logger.error(f"Failed to send message: {result.payload}") + attempts += 1 + if flood: + flood_attempts += 1 + continue exp_ack = result.payload["expected_ack"].hex() timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout @@ -255,7 +259,7 @@ class MessagingCommands(CommandHandlerBase): elif path_hash_len == 8 : flags = 3 else : - logger.error(f"Invalid path format: {e}") + logger.error(f"Invalid path format: unknown path_hash_len {path_hash_len}") return Event(EventType.ERROR, {"reason": "invalid_path_format"}) else: flags = 0 diff --git a/src/meshcore/events.py b/src/meshcore/events.py index f8a7521..258dc7f 100644 --- a/src/meshcore/events.py +++ b/src/meshcore/events.py @@ -104,6 +104,17 @@ class Event: if kwargs: self.attributes.update(kwargs) + def is_error(self) -> bool: + """Return True if this event represents an error response. + + Callers that include ``EventType.ERROR`` in their expected-events + list **must** check ``result.is_error()`` (or ``result.type == + EventType.ERROR``) before accessing keyed payload fields, because + an ERROR payload contains ``{"reason": "..."}`` — not the + command-specific keys the caller expects on the happy path. + """ + return self.type == EventType.ERROR + def clone(self): """ Create a copy of the event. diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py new file mode 100644 index 0000000..7f88e28 --- /dev/null +++ b/tests/unit/test_error_handling.py @@ -0,0 +1,236 @@ +"""Verification tests for error response handling fixes. + +The tests confirm that error responses are surfaced cleanly instead +of causing KeyError, TypeError, NameError, or silent fallthrough. +""" +import asyncio +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from meshcore.commands import CommandHandler +from meshcore.events import EventType, Event, Subscription + +pytestmark = pytest.mark.asyncio + +VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes + + +# ── Fixtures ─────────────────────────────────────────────────────── + +@pytest.fixture +def mock_connection(): + connection = MagicMock() + connection.send = AsyncMock() + return connection + + +@pytest.fixture +def mock_dispatcher(): + dispatcher = MagicMock() + dispatcher.wait_for_event = AsyncMock() + dispatcher.dispatch = AsyncMock() + + def fake_subscribe(event_type, handler, attribute_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + dispatcher._last_subscribe_handler = handler + dispatcher._last_subscribe_event_type = event_type + return sub + + dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + return dispatcher + + +@pytest.fixture +def command_handler(mock_connection, mock_dispatcher): + handler = CommandHandler() + + async def sender(data): + await mock_connection.send(data) + + handler._sender_func = sender + handler.dispatcher = mock_dispatcher + return handler + + +def setup_error_response(mock_dispatcher): + """Configure dispatcher to return an ERROR event for any subscribe.""" + def fake_subscribe(evt_type, handler, attr_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + # Always fire ERROR regardless of which event type was subscribed + if evt_type == EventType.ERROR: + asyncio.get_event_loop().call_soon( + handler, Event(EventType.ERROR, {"reason": "test_error"}) + ) + return sub + + mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + + +def setup_event_response(mock_dispatcher, event_type, payload): + """Configure dispatcher to return a specific event.""" + def fake_subscribe(evt_type, handler, attr_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + if evt_type == event_type: + asyncio.get_event_loop().call_soon( + handler, Event(event_type, payload) + ) + return sub + + mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + + +# ── Event.is_error() helper ────────────────────────────────── + +async def test_event_is_error_true(): + """is_error() returns True for ERROR events.""" + event = Event(EventType.ERROR, {"reason": "test"}) + assert event.is_error() is True + + +async def test_event_is_error_false(): + """is_error() returns False for non-ERROR events.""" + event = Event(EventType.OK, {}) + assert event.is_error() is False + event2 = Event(EventType.SELF_INFO, {"name": "test"}) + assert event2.is_error() is False + + +# ── send_msg_with_retry continues on ERROR ────────────── + +async def test_send_msg_with_retry_error_no_keyerror( + command_handler, mock_dispatcher +): + """send_msg_with_retry returns None (exhausted retries) on + persistent ERROR instead of raising KeyError on missing 'expected_ack'.""" + setup_error_response(mock_dispatcher) + + # Provide a mock contact so the path logic doesn't interfere + command_handler._get_contact_by_prefix = MagicMock(return_value=None) + + # max_attempts=2 so it retries once then gives up + result = await command_handler.send_msg_with_retry( + VALID_PUBKEY_HEX, "hello", max_attempts=2, timeout=0.1 + ) + + # Should return None (no ACK received) rather than raising KeyError + assert result is None + + +# ── send_appstart includes ERROR in expected events ────────── + +async def test_send_appstart_returns_error( + command_handler, mock_dispatcher +): + """send_appstart returns ERROR event instead of hanging on timeout.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.send_appstart() + + assert result.type == EventType.ERROR + assert result.is_error() is True + assert result.payload["reason"] == "test_error" + + +# ── device setters return ERROR from send_appstart ─────────── + +async def test_set_telemetry_mode_base_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_base returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_base(1) + + assert result.is_error() + assert result.payload["reason"] == "test_error" + + +async def test_set_telemetry_mode_loc_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_loc returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_loc(1) + + assert result.is_error() + + +async def test_set_telemetry_mode_env_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_env returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_env(1) + + assert result.is_error() + + +async def test_set_manual_add_contacts_error( + command_handler, mock_dispatcher +): + """set_manual_add_contacts returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_manual_add_contacts(True) + + assert result.is_error() + + +async def test_set_advert_loc_policy_error( + command_handler, mock_dispatcher +): + """set_advert_loc_policy returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_advert_loc_policy(1) + + assert result.is_error() + + +async def test_set_multi_acks_error( + command_handler, mock_dispatcher +): + """set_multi_acks returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_multi_acks(1) + + assert result.is_error() + + +# ── send_anon_req returns ERROR on contact not found ───────── + +async def test_send_anon_req_contact_not_found( + command_handler, mock_dispatcher +): + """send_anon_req returns ERROR event when contact prefix not found, + instead of raising TypeError on NoneType subscript.""" + command_handler._get_contact_by_prefix = MagicMock(return_value=None) + + result = await command_handler.send_anon_req( + VALID_PUBKEY_HEX, MagicMock(value=1) + ) + + assert result.is_error() + assert result.payload["reason"] == "contact_not_found" + + +# ── send_trace handles unknown path_hash_len without NameError ── + +async def test_send_trace_unknown_path_hash_len( + command_handler, mock_connection, mock_dispatcher +): + """send_trace with a path whose segments don't match any known + path_hash_len returns ERROR cleanly instead of NameError on 'e'.""" + # 5-char hex segments → path_hash_len = 2.5 → doesn't match 1,2,4,8 + result = await command_handler.send_trace( + auth_code=0, tag=1, flags=None, path="abcde" + ) + + assert result.is_error() + assert result.payload["reason"] == "invalid_path_format"