diff --git a/README.md b/README.md index bea2c49..c87cfc6 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,11 @@ async def main(): # Send a message to the first contact if contacts: - contact_key = next(iter(contacts.items()))[1]['public_key'] - await meshcore.commands.send_msg(bytes.fromhex(contact_key), "Hello from Python!") + # Get the first contact + contact = next(iter(contacts.items()))[1] + + # Pass the contact object directly to send_msg + await meshcore.commands.send_msg(contact, "Hello from Python!") await meshcore.disconnect() @@ -247,15 +250,31 @@ This logs detailed information about commands sent and events received. ### Sending Messages to Contacts +Commands that require a destination (`send_msg`, `send_login`, `send_statusreq`, etc.) now accept either: +- A string with the hex representation of a public key +- A contact object with a "public_key" field +- Bytes object (for backward compatibility) + ```python # Get contacts and send to a specific one contacts = await meshcore.commands.get_contacts() for key, contact in contacts.items(): if contact["adv_name"] == "Alice": - # Convert the hex key to bytes + # Option 1: Pass the contact object directly + await meshcore.commands.send_msg(contact, "Hello Alice!") + + # Option 2: Use the public key string + await meshcore.commands.send_msg(contact["public_key"], "Hello again Alice!") + + # Option 3 (backward compatible): Convert the hex key to bytes dst_key = bytes.fromhex(contact["public_key"]) - await meshcore.commands.send_msg(dst_key, "Hello Alice!") + await meshcore.commands.send_msg(dst_key, "Hello once more Alice!") break + +# You can also directly use a contact found by name +contact = meshcore.get_contact_by_name("Bob") +if contact: + await meshcore.commands.send_msg(contact, "Hello Bob!") ``` ### Monitoring Channel Messages diff --git a/examples/ble_t1000_chan_msg.py b/examples/ble_t1000_chan_msg.py index 996f718..0eacf7a 100755 --- a/examples/ble_t1000_chan_msg.py +++ b/examples/ble_t1000_chan_msg.py @@ -15,6 +15,6 @@ async def main () : mc = MeshCore(con) await mc.connect() - await mc.send_chan_msg(0, MSG) + await mc.commands.send_chan_msg(0, MSG) asyncio.run(main()) diff --git a/examples/ble_t1000_msg.py b/examples/ble_t1000_msg.py index c9ec543..d190cca 100755 --- a/examples/ble_t1000_msg.py +++ b/examples/ble_t1000_msg.py @@ -14,7 +14,11 @@ async def main () : mc = MeshCore(con) await mc.connect() - await mc.get_contacts() - await mc.commands.send_msg(bytes.fromhex(mc.contacts[DEST]["public_key"])[0:6],MSG) + await mc.ensure_contacts() + contact = mc.get_contact_by_name(DEST) + if contact is None: + print(f"Contact '{DEST}' not found in contacts.") + return + await mc.commands.send_msg(contact,MSG) asyncio.run(main()) diff --git a/examples/serial_msg.py b/examples/serial_msg.py index cb21a10..a5a055d 100755 --- a/examples/serial_msg.py +++ b/examples/serial_msg.py @@ -38,7 +38,7 @@ async def main(): # Send the message and get the MSG_SENT event print(f"Sending message: '{args.message}'") send_result = await mc.commands.send_msg( - bytes.fromhex(contact["public_key"])[0:6], + contact, args.message ) diff --git a/examples/serial_repeater_status.py b/examples/serial_repeater_status.py index 2a77336..134f8e9 100755 --- a/examples/serial_repeater_status.py +++ b/examples/serial_repeater_status.py @@ -16,7 +16,10 @@ async def main () : await mc.commands.get_contacts() repeater = mc.get_contact_by_name(REPEATER) - await mc.commands.send_login(bytes.fromhex(repeater["public_key"]), PASSWORD) + if repeater is None: + print(f"Repeater '{REPEATER}' not found in contacts.") + return + await mc.commands.send_login(repeater, PASSWORD) print("Login sent ... awaiting") diff --git a/examples/tcp_mchome_msg.py b/examples/tcp_mchome_msg.py index ec44d38..cf8762f 100755 --- a/examples/tcp_mchome_msg.py +++ b/examples/tcp_mchome_msg.py @@ -17,6 +17,10 @@ async def main () : await mc.connect() await mc.ensure_contacts() - await mc.commands.send_msg(bytes.fromhex(mc.get_contact_by_name(DEST)["public_key"])[0:6],MSG) + contact = mc.get_contact_by_name(DEST) + if contact is None: + print(f"Contact '{DEST}' not found in contacts.") + return + await mc.commands.send_msg(contact ,MSG) asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 601dfed..10deec2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ license = "MIT" license-files = ["LICEN[CS]E*"] dependencies = [ "bleak", "pyserial-asyncio" ] +[project.optional-dependencies] +dev = ["pytest", "pytest-asyncio"] + [project.urls] Homepage = "https://github.com/fdlamotte/meshcore_py" Issues = "https://github.com/fdlamotte/meshcore_py/issues" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d280de0 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto \ No newline at end of file diff --git a/src/meshcore/commands.py b/src/meshcore/commands.py index 56e0f4b..85d9f9e 100644 --- a/src/meshcore/commands.py +++ b/src/meshcore/commands.py @@ -1,32 +1,72 @@ import asyncio import logging -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Union from .events import EventType import random + +# Define types for destination parameters +DestinationType = Union[bytes, str, Dict[str, Any]] logger = logging.getLogger("meshcore") +def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes: + """ + Validates and converts a destination to a bytes object. + + Args: + dst: The destination, which can be: + - str: Hex string representation of a public key + - dict: Contact object with a "public_key" field + prefix_length: The length of the prefix to use (default: 6 bytes) + + Returns: + bytes: The destination public key as a bytes object + + Raises: + ValueError: If dst is invalid or doesn't contain required fields + """ + if isinstance(dst, bytes): + # Already bytes, use directly + return dst[:prefix_length] + elif isinstance(dst, str): + # Hex string, convert to bytes + try: + return bytes.fromhex(dst)[:prefix_length] + except ValueError: + raise ValueError(f"Invalid public key hex string: {dst}") + elif isinstance(dst, dict): + # Contact object, extract public_key + if "public_key" not in dst: + raise ValueError("Contact object must have a 'public_key' field") + try: + return bytes.fromhex(dst["public_key"])[:prefix_length] + except ValueError: + raise ValueError(f"Invalid public_key in contact: {dst['public_key']}") + else: + raise ValueError(f"Destination must be a public key string or contact object, got: {type(dst)}") + class CommandHandler: DEFAULT_TIMEOUT = 5.0 - def __init__(self, default_timeout=None): + def __init__(self, default_timeout: Optional[float] = None): self._sender_func = None self._reader = None self.dispatcher = None self.default_timeout = default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT - def set_connection(self, connection): - async def sender(data): + def set_connection(self, connection: Any) -> None: + async def sender(data: bytes) -> None: await connection.send(data) self._sender_func = sender - def set_reader(self, reader): + def set_reader(self, reader: Any) -> None: self._reader = reader - def set_dispatcher(self, dispatcher): + def set_dispatcher(self, dispatcher: Any) -> None: self.dispatcher = dispatcher - async def send(self, data, expected_events=None, timeout=None) -> Dict[str, Any]: + async def send(self, data: bytes, expected_events: Optional[Union[EventType, List[EventType]]] = None, + timeout: Optional[float] = None) -> Dict[str, Any]: """ Send a command and wait for expected event responses. @@ -70,54 +110,54 @@ class CommandHandler: return {"success": True} - async def send_appstart(self): + async def send_appstart(self) -> Dict[str, Any]: logger.debug("Sending appstart command") b1 = bytearray(b'\x01\x03 mccli') return await self.send(b1, [EventType.SELF_INFO]) - async def send_device_query(self): + async def send_device_query(self) -> Dict[str, Any]: logger.debug("Sending device query command") return await self.send(b"\x16\x03", [EventType.DEVICE_INFO, EventType.ERROR]) - async def send_advert(self, flood=False): + async def send_advert(self, flood: bool = False) -> Dict[str, Any]: logger.debug(f"Sending advertisement command (flood={flood})") if flood: return await self.send(b"\x07\x01", [EventType.OK, EventType.ERROR]) else: return await self.send(b"\x07", [EventType.OK, EventType.ERROR]) - async def set_name(self, name): + async def set_name(self, name: str) -> Dict[str, Any]: logger.debug(f"Setting device name to: {name}") return await self.send(b'\x08' + name.encode("ascii"), [EventType.OK, EventType.ERROR]) - async def set_coords(self, lat, lon): + async def set_coords(self, lat: float, lon: float) -> Dict[str, Any]: logger.debug(f"Setting coordinates to: lat={lat}, lon={lon}") return await self.send(b'\x0e'\ + int(lat*1e6).to_bytes(4, 'little', signed=True)\ + int(lon*1e6).to_bytes(4, 'little', signed=True)\ + int(0).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - async def reboot(self): + async def reboot(self) -> Dict[str, Any]: logger.debug("Sending reboot command") return await self.send(b'\x13reboot') - async def get_bat(self): + async def get_bat(self) -> Dict[str, Any]: logger.debug("Getting battery information") return await self.send(b'\x14', [EventType.BATTERY, EventType.ERROR]) - async def get_time(self): + async def get_time(self) -> Dict[str, Any]: logger.debug("Getting device time") return await self.send(b"\x05", [EventType.CURRENT_TIME, EventType.ERROR]) - async def set_time(self, val): + async def set_time(self, val: int) -> Dict[str, Any]: logger.debug(f"Setting device time to: {val}") return await self.send(b"\x06" + int(val).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - async def set_tx_power(self, val): + async def set_tx_power(self, val: int) -> Dict[str, Any]: logger.debug(f"Setting TX power to: {val}") return await self.send(b"\x0c" + int(val).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - async def set_radio(self, freq, bw, sf, cr): + async def set_radio(self, freq: float, bw: float, sf: int, cr: int) -> Dict[str, Any]: logger.debug(f"Setting radio params: freq={freq}, bw={bw}, sf={sf}, cr={cr}") return await self.send(b"\x0b" \ + int(float(freq)*1000).to_bytes(4, 'little')\ @@ -125,7 +165,7 @@ class CommandHandler: + int(sf).to_bytes(1, 'little')\ + int(cr).to_bytes(1, 'little'), [EventType.OK, EventType.ERROR]) - async def set_tuning(self, rx_dly, af): + async def set_tuning(self, rx_dly: int, af: int) -> Dict[str, Any]: logger.debug(f"Setting tuning params: rx_dly={rx_dly}, af={af}") return await self.send(b"\x15" \ + int(rx_dly).to_bytes(4, 'little')\ @@ -133,74 +173,85 @@ class CommandHandler: + int(0).to_bytes(1, 'little')\ + int(0).to_bytes(1, 'little'), [EventType.OK, EventType.ERROR]) - async def set_devicepin(self, pin): + async def set_devicepin(self, pin: int) -> Dict[str, Any]: logger.debug(f"Setting device PIN to: {pin}") return await self.send(b"\x25" \ + int(pin).to_bytes(4, 'little'), [EventType.OK, EventType.ERROR]) - async def get_contacts(self): + async def get_contacts(self) -> Dict[str, Any]: logger.debug("Getting contacts") return await self.send(b"\x04", [EventType.CONTACTS, EventType.ERROR]) - async def reset_path(self, key): - logger.debug(f"Resetting path for contact: {key.hex() if isinstance(key, bytes) else key}") - data = b"\x0D" + key + async def reset_path(self, key: DestinationType) -> Dict[str, Any]: + key_bytes = _validate_destination(key) + logger.debug(f"Resetting path for contact: {key_bytes.hex()}") + data = b"\x0D" + key_bytes return await self.send(data, [EventType.OK, EventType.ERROR]) - async def share_contact(self, key): - logger.debug(f"Sharing contact: {key.hex() if isinstance(key, bytes) else key}") - data = b"\x10" + key + async def share_contact(self, key: DestinationType) -> Dict[str, Any]: + key_bytes = _validate_destination(key) + logger.debug(f"Sharing contact: {key_bytes.hex()}") + data = b"\x10" + key_bytes return await self.send(data, [EventType.CONTACT_SHARE, EventType.ERROR]) - async def export_contact(self, key=b""): - logger.debug(f"Exporting contact: {key.hex() if key else 'all'}") - data = b"\x11" + key + async def export_contact(self, key: Optional[DestinationType] = None) -> Dict[str, Any]: + if key: + key_bytes = _validate_destination(key) + logger.debug(f"Exporting contact: {key_bytes.hex()}") + data = b"\x11" + key_bytes + else: + logger.debug("Exporting all contacts") + data = b"\x11" return await self.send(data, [EventType.OK, EventType.ERROR]) - async def remove_contact(self, key): - logger.debug(f"Removing contact: {key.hex() if isinstance(key, bytes) else key}") - data = b"\x0f" + key + async def remove_contact(self, key: DestinationType) -> Dict[str, Any]: + key_bytes = _validate_destination(key) + logger.debug(f"Removing contact: {key_bytes.hex()}") + data = b"\x0f" + key_bytes return await self.send(data, [EventType.OK, EventType.ERROR]) - async def get_msg(self, timeout=1): + async def get_msg(self, timeout: Optional[float] = 1) -> Dict[str, Any]: logger.debug("Requesting pending messages") return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR], timeout) - async def send_login(self, dst, pwd): - logger.debug(f"Sending login request to: {dst.hex() if isinstance(dst, bytes) else dst}") - data = b"\x1a" + dst + pwd.encode("ascii") + async def send_login(self, dst: DestinationType, pwd: str) -> Dict[str, Any]: + dst_bytes = _validate_destination(dst) + logger.debug(f"Sending login request to: {dst_bytes.hex()}") + data = b"\x1a" + dst_bytes + pwd.encode("ascii") return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - async def send_logout(self, dst): + async def send_logout(self, dst: DestinationType) -> Dict[str, Any]: + dst_bytes = _validate_destination(dst) self.login_resp = asyncio.Future() - data = b"\x1d" + dst + data = b"\x1d" + dst_bytes return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - async def send_statusreq(self, dst): - logger.debug(f"Sending status request to: {dst.hex() if isinstance(dst, bytes) else dst}") - data = b"\x1b" + dst + async def send_statusreq(self, dst: DestinationType) -> Dict[str, Any]: + dst_bytes = _validate_destination(dst) + logger.debug(f"Sending status request to: {dst_bytes.hex()}") + data = b"\x1b" + dst_bytes return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) - async def send_cmd(self, dst, cmd, timestamp=None): - logger.debug(f"Sending command to {dst.hex() if isinstance(dst, bytes) else dst}: {cmd}") + async def send_cmd(self, dst: DestinationType, cmd: str, timestamp: Optional[int] = None) -> Dict[str, Any]: + dst_bytes = _validate_destination(dst) + logger.debug(f"Sending command to {dst_bytes.hex()}: {cmd}") - # Default to current time if timestamp not provided if timestamp is None: import time - timestamp = int(time.time()).to_bytes(4, 'little') + timestamp = int(time.time()) - data = b"\x02\x01\x00" + timestamp + dst + cmd.encode("ascii") + data = b"\x02\x01\x00" + timestamp.to_bytes(4, 'little') + dst_bytes + cmd.encode("ascii") return await self.send(data, [EventType.OK, EventType.ERROR]) - async def send_msg(self, dst, msg, timestamp=None): - logger.debug(f"Sending message to {dst.hex() if isinstance(dst, bytes) else dst}: {msg}") + async def send_msg(self, dst: DestinationType, msg: str, timestamp: Optional[int] = None) -> Dict[str, Any]: + dst_bytes = _validate_destination(dst) + logger.debug(f"Sending message to {dst_bytes.hex()}: {msg}") - # Default to current time if timestamp not provided if timestamp is None: import time - timestamp = int(time.time()).to_bytes(4, 'little') + timestamp = int(time.time()) - data = b"\x02\x00\x00" + timestamp + dst + msg.encode("ascii") + data = b"\x02\x00\x00" + timestamp.to_bytes(4, 'little') + dst_bytes + msg.encode("ascii") return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) async def send_chan_msg(self, chan, msg, timestamp=None): @@ -219,7 +270,8 @@ class CommandHandler: data = b"\x32" + cmd.encode('ascii') return await self.send(data, [EventType.CLI_RESPONSE, EventType.ERROR]) - async def send_trace(self, auth_code=0, tag=None, flags=0, path=None): + async def send_trace(self, auth_code: int = 0, tag: Optional[int] = None, + flags: int = 0, path: Optional[Union[str, bytes, bytearray]] = None) -> Dict[str, Any]: """ Send a trace packet to test routing through specific repeaters diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..d047299 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,25 @@ +# MeshCore Tests + +## Running Tests + +To run the tests, first install the development dependencies: + +```bash +pip install -e ".[dev]" +``` + +Then run the tests using pytest: + +```bash +# Run all tests +pytest + +# Run tests with verbose output +pytest -v + +# Run a specific test file +pytest tests/unit/test_commands.py + +# Run a specific test +pytest tests/unit/test_commands.py::test_send_msg +``` \ No newline at end of file diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py new file mode 100644 index 0000000..9b90562 --- /dev/null +++ b/tests/unit/test_commands.py @@ -0,0 +1,234 @@ +import pytest +import asyncio +from unittest.mock import MagicMock, patch, AsyncMock +from meshcore.commands import CommandHandler +from meshcore.events import EventType, Event + +pytestmark = pytest.mark.asyncio + +# Fixtures +@pytest.fixture +def mock_connection(): + connection = MagicMock() + connection.send = AsyncMock() + return connection + +@pytest.fixture +def mock_dispatcher(): + dispatcher = MagicMock() + dispatcher.wait_for_event = AsyncMock() + dispatcher.dispatch = AsyncMock() + return dispatcher + +@pytest.fixture +def command_handler(mock_connection, mock_dispatcher): + handler = CommandHandler() + + async def sender(data): + await mock_connection.send(data) + handler._sender_func = sender + + handler.dispatcher = mock_dispatcher + return handler + +# Test helper +def setup_event_response(mock_dispatcher, event_type, payload, attribute_filters=None): + async def wait_response(requested_type, filters=None, timeout=None): + if requested_type == event_type: + if filters and attribute_filters: + if not all(attribute_filters.get(key) == value for key, value in filters.items()): + return None + return Event(event_type, payload) + return None + + mock_dispatcher.wait_for_event.side_effect = wait_response + +# Basic tests +async def test_send_basic(command_handler, mock_connection): + result = await command_handler.send(b"test_data") + mock_connection.send.assert_called_once_with(b"test_data") + assert result == {"success": True} + +async def test_send_with_event(command_handler, mock_connection, mock_dispatcher): + expected_payload = {"success": True, "value": 42} + setup_event_response(mock_dispatcher, EventType.OK, expected_payload) + + result = await command_handler.send(b"test_command", [EventType.OK]) + + mock_connection.send.assert_called_once_with(b"test_command") + assert result == expected_payload + +async def test_send_timeout(command_handler, mock_connection, mock_dispatcher): + mock_dispatcher.wait_for_event.side_effect = asyncio.TimeoutError + + result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1) + assert result == {"success": False, "reason": "timeout"} + +# Destination validation tests +async def test_validate_destination_bytes(command_handler, mock_connection): + dst = b"123456789012" # 12 bytes + await command_handler.send_msg(dst, "test message") + + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") + assert b"123456" in mock_connection.send.call_args[0][0] + +async def test_validate_destination_hex_string(command_handler, mock_connection): + dst = "0123456789abcdef" + await command_handler.send_msg(dst, "test message") + + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +async def test_validate_destination_contact_object(command_handler, mock_connection): + dst = {"public_key": "0123456789abcdef", "adv_name": "Test Contact"} + await command_handler.send_msg(dst, "test message") + + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +# Command tests +async def test_send_login(command_handler, mock_connection): + await command_handler.send_login("0123456789abcdef", "password") + + assert mock_connection.send.call_args[0][0].startswith(b"\x1a") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + assert b"password" in mock_connection.send.call_args[0][0] + +async def test_send_msg(command_handler, mock_connection): + await command_handler.send_msg("0123456789abcdef", "hello") + + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x00\x00") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + assert b"hello" in mock_connection.send.call_args[0][0] + +async def test_send_cmd(command_handler, mock_connection): + await command_handler.send_cmd("0123456789abcdef", "test_cmd") + + assert mock_connection.send.call_args[0][0].startswith(b"\x02\x01\x00") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + assert b"test_cmd" in mock_connection.send.call_args[0][0] + +# Device settings tests +async def test_set_name(command_handler, mock_connection): + await command_handler.set_name("Test Device") + + assert mock_connection.send.call_args[0][0].startswith(b"\x08") + assert b"Test Device" in mock_connection.send.call_args[0][0] + +async def test_set_coords(command_handler, mock_connection): + await command_handler.set_coords(37.7749, -122.4194) + + assert mock_connection.send.call_args[0][0].startswith(b"\x0e") + # Could add more detailed assertions for the byte encoding + +async def test_send_appstart(command_handler, mock_connection): + await command_handler.send_appstart() + assert mock_connection.send.call_args[0][0].startswith(b"\x01\x03") + assert b"mccli" in mock_connection.send.call_args[0][0] + +async def test_send_device_query(command_handler, mock_connection): + await command_handler.send_device_query() + assert mock_connection.send.call_args[0][0].startswith(b"\x16\x03") + +async def test_send_advert(command_handler, mock_connection): + # Test without flood + await command_handler.send_advert(flood=False) + assert mock_connection.send.call_args[0][0] == b"\x07" + + # Test with flood + mock_connection.reset_mock() + await command_handler.send_advert(flood=True) + assert mock_connection.send.call_args[0][0] == b"\x07\x01" + +async def test_reboot(command_handler, mock_connection): + await command_handler.reboot() + assert mock_connection.send.call_args[0][0].startswith(b"\x13reboot") + +async def test_get_bat(command_handler, mock_connection): + await command_handler.get_bat() + assert mock_connection.send.call_args[0][0].startswith(b"\x14") + +async def test_get_time(command_handler, mock_connection): + await command_handler.get_time() + assert mock_connection.send.call_args[0][0].startswith(b"\x05") + +async def test_set_time(command_handler, mock_connection): + timestamp = 1620000000 # Example timestamp + await command_handler.set_time(timestamp) + assert mock_connection.send.call_args[0][0].startswith(b"\x06") + +async def test_set_tx_power(command_handler, mock_connection): + await command_handler.set_tx_power(20) + assert mock_connection.send.call_args[0][0].startswith(b"\x0c") + +async def test_get_contacts(command_handler, mock_connection): + await command_handler.get_contacts() + assert mock_connection.send.call_args[0][0].startswith(b"\x04") + +async def test_reset_path(command_handler, mock_connection): + dst = "0123456789abcdef" + await command_handler.reset_path(dst) + assert mock_connection.send.call_args[0][0].startswith(b"\x0D") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +async def test_share_contact(command_handler, mock_connection): + dst = "0123456789abcdef" + await command_handler.share_contact(dst) + assert mock_connection.send.call_args[0][0].startswith(b"\x10") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +async def test_export_contact(command_handler, mock_connection): + # Test exporting all contacts + await command_handler.export_contact() + assert mock_connection.send.call_args[0][0] == b"\x11" + + # Test exporting specific contact + mock_connection.reset_mock() + dst = "0123456789abcdef" + await command_handler.export_contact(dst) + assert mock_connection.send.call_args[0][0].startswith(b"\x11") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +async def test_remove_contact(command_handler, mock_connection): + dst = "0123456789abcdef" + await command_handler.remove_contact(dst) + assert mock_connection.send.call_args[0][0].startswith(b"\x0f") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +async def test_get_msg(command_handler, mock_connection): + await command_handler.get_msg() + assert mock_connection.send.call_args[0][0].startswith(b"\x0A") + + # Test with custom timeout + mock_connection.reset_mock() + await command_handler.get_msg(timeout=5.0) + assert mock_connection.send.call_args[0][0].startswith(b"\x0A") + +async def test_send_logout(command_handler, mock_connection): + dst = "0123456789abcdef" + await command_handler.send_logout(dst) + assert mock_connection.send.call_args[0][0].startswith(b"\x1d") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +async def test_send_statusreq(command_handler, mock_connection): + dst = "0123456789abcdef" + await command_handler.send_statusreq(dst) + assert mock_connection.send.call_args[0][0].startswith(b"\x1b") + assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] + +async def test_send_trace(command_handler, mock_connection): + # Test with minimal parameters + await command_handler.send_trace() + first_call = mock_connection.send.call_args[0][0] + assert first_call.startswith(b"\x24") # 36 in decimal = 0x24 in hex + + # Test with all parameters + mock_connection.reset_mock() + await command_handler.send_trace( + auth_code=12345, + tag=67890, + flags=1, + path="01,23,45" + ) + second_call = mock_connection.send.call_args[0][0] + assert second_call.startswith(b"\x24") \ No newline at end of file diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py new file mode 100644 index 0000000..a2925f6 --- /dev/null +++ b/tests/unit/test_events.py @@ -0,0 +1,112 @@ +import pytest +import asyncio +from unittest.mock import MagicMock, AsyncMock +from meshcore.events import EventDispatcher, EventType, Event + +pytestmark = pytest.mark.asyncio + +@pytest.fixture +def dispatcher(): + return EventDispatcher() + +async def test_subscribe_with_attribute_filter(dispatcher): + callback = MagicMock() + + # Subscribe with attribute filters + subscription = dispatcher.subscribe( + EventType.MSG_SENT, + callback, + attribute_filters={"type": 1, "expected_ack": "1234"} + ) + + # Start the dispatcher + await dispatcher.start() + + try: + # Dispatch event that should NOT match (wrong type) + await dispatcher.dispatch(Event( + EventType.MSG_SENT, + {"some": "data"}, + {"type": 2, "expected_ack": "1234"} + )) + await asyncio.sleep(0.1) # Allow processing + + # Callback should NOT have been called + assert callback.call_count == 0 + + # Dispatch event that should match all filters + await dispatcher.dispatch(Event( + EventType.MSG_SENT, + {"some": "data"}, + {"type": 1, "expected_ack": "1234"} + )) + await asyncio.sleep(0.1) # Allow processing + + # Callback should have been called once + assert callback.call_count == 1 + + finally: + await dispatcher.stop() + +async def test_wait_for_event_with_attribute_filter(dispatcher): + await dispatcher.start() + + try: + future_event = asyncio.create_task( + dispatcher.wait_for_event( + EventType.ACK, + attribute_filters={"code": "1234"}, + timeout=3.0 + ) + ) + + await asyncio.sleep(0.1) + + await dispatcher.dispatch(Event( + EventType.ACK, + {"some": "data"}, + {"code": "5678"} + )) + + await asyncio.sleep(0.1) + + await dispatcher.dispatch(Event( + EventType.ACK, + {"ack": "data"}, + {"code": "1234"} + )) + + result = await asyncio.wait_for(future_event, 3.0) + + assert result is not None + assert result.type == EventType.ACK + assert result.attributes["code"] == "1234" + assert result.payload == {"ack": "data"} + + finally: + await dispatcher.stop() + +async def test_wait_for_event_timeout_with_filter(dispatcher): + await dispatcher.start() + + try: + # Wait for an event that won't arrive + result = await dispatcher.wait_for_event( + EventType.ACK, + attribute_filters={"code": "1234"}, + timeout=0.1 + ) + + # Should get None due to timeout + assert result is None + + finally: + await dispatcher.stop() + +async def test_event_init_with_kwargs(): + # Test creating an event with keyword attributes + event = Event(EventType.ACK, {"data": "value"}, code="1234", status="ok") + + assert event.type == EventType.ACK + assert event.payload == {"data": "value"} + assert event.attributes == {"code": "1234", "status": "ok"} \ No newline at end of file