mirror of
https://github.com/meshcore-dev/meshcore_py.git
synced 2026-04-20 22:13:49 +00:00
Refactor to event system
This commit is contained in:
parent
8f0ecd7d75
commit
a5f1ec5c26
7 changed files with 66 additions and 271 deletions
|
|
@ -62,6 +62,9 @@ class BLEConnection:
|
|||
await self.client.start_notify(UART_TX_CHAR_UUID, self.handle_rx)
|
||||
|
||||
nus = self.client.services.get_service(UART_SERVICE_UUID)
|
||||
if nus is None:
|
||||
logger.error("Could not find UART service")
|
||||
return None
|
||||
self.rx_char = nus.get_characteristic(UART_RX_CHAR_UUID)
|
||||
|
||||
logger.info("BLE Connection started")
|
||||
|
|
@ -82,4 +85,10 @@ class BLEConnection:
|
|||
asyncio.create_task(self.reader.handle_rx(data))
|
||||
|
||||
async def send(self, data):
|
||||
if not self.client:
|
||||
logger.error("Client is not connected")
|
||||
return False
|
||||
if not self.rx_char:
|
||||
logger.error("RX characteristic not found")
|
||||
return False
|
||||
await self.client.write_gatt_char(self.rx_char, bytes(data), response=False)
|
||||
|
|
|
|||
|
|
@ -26,10 +26,13 @@ def deprecated(func):
|
|||
|
||||
|
||||
class CommandHandler:
|
||||
def __init__(self):
|
||||
DEFAULT_TIMEOUT = 5.0
|
||||
|
||||
def __init__(self, default_timeout=None):
|
||||
self._sender_func = None
|
||||
self._reader = None
|
||||
self.dispatcher = None
|
||||
self.default_timeout = default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
|
||||
|
||||
def set_connection(self, connection):
|
||||
async def sender(data):
|
||||
|
|
@ -42,10 +45,13 @@ class CommandHandler:
|
|||
def set_dispatcher(self, dispatcher):
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
async def send(self, data, expected_events=None, timeout=5.0):
|
||||
async def send(self, data, expected_events=None, timeout=None):
|
||||
if not self.dispatcher:
|
||||
raise RuntimeError("Dispatcher not set, cannot send commands")
|
||||
|
||||
# Use the provided timeout or fall back to default_timeout
|
||||
timeout = timeout if timeout is not None else self.default_timeout
|
||||
|
||||
if self._sender_func:
|
||||
logger.debug(f"Sending raw data: {data.hex() if isinstance(data, bytes) else data}")
|
||||
await self._sender_func(data)
|
||||
|
|
@ -163,15 +169,20 @@ class CommandHandler:
|
|||
data = b"\x0f" + key
|
||||
return await self.send(data, [EventType.OK, EventType.ERROR])
|
||||
|
||||
async def get_msg(self):
|
||||
async def get_msg(self, timeout=1):
|
||||
logger.debug("Requesting pending messages")
|
||||
return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR], 1)
|
||||
return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR], timeout)
|
||||
|
||||
async def send_login(self, dst, pwd):
|
||||
logger.debug(f"Sending login request to: {dst.hex() if isinstance(dst, bytes) else dst}")
|
||||
data = b"\x1a" + dst + pwd.encode("ascii")
|
||||
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
|
||||
|
||||
async def send_logout(self, dst):
|
||||
self.login_resp = asyncio.Future()
|
||||
data = b"\x1d" + dst
|
||||
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
|
||||
|
||||
async def send_statusreq(self, dst):
|
||||
logger.debug(f"Sending status request to: {dst.hex() if isinstance(dst, bytes) else dst}")
|
||||
data = b"\x1b" + dst
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class EventType(Enum):
|
|||
class Event:
|
||||
type: EventType
|
||||
payload: Any
|
||||
attributes: Dict[str, Any] = None
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
||||
def __post_init__(self):
|
||||
if self.attributes is None:
|
||||
|
|
@ -64,7 +64,7 @@ class EventDispatcher:
|
|||
self.running = False
|
||||
self._task = None
|
||||
|
||||
def subscribe(self, event_type: Union[EventType, None], callback: Callable[[Event], None]) -> Subscription:
|
||||
def subscribe(self, event_type: Union[EventType, None], callback: Callable[[Event], Union[None, asyncio.Future]]) -> Subscription:
|
||||
subscription = Subscription(self, event_type, callback)
|
||||
self.subscriptions.append(subscription)
|
||||
return subscription
|
||||
|
|
@ -83,7 +83,9 @@ class EventDispatcher:
|
|||
for subscription in self.subscriptions.copy():
|
||||
if subscription.event_type is None or subscription.event_type == event.type:
|
||||
try:
|
||||
await subscription.callback(event)
|
||||
result = subscription.callback(event)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception as e:
|
||||
print(f"Error in event handler: {e}")
|
||||
|
||||
|
|
@ -106,13 +108,13 @@ class EventDispatcher:
|
|||
pass
|
||||
self._task = None
|
||||
|
||||
async def wait_for_event(self, event_type: EventType, timeout: float = None) -> Optional[Event]:
|
||||
async def wait_for_event(self, event_type: EventType, timeout: float | None = None) -> Optional[Event]:
|
||||
future = asyncio.Future()
|
||||
|
||||
async def event_handler(event: Event):
|
||||
def event_handler(event: Event):
|
||||
if not future.done():
|
||||
future.set_result(event)
|
||||
|
||||
|
||||
subscription = self.subscribe(event_type, event_handler)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -28,11 +28,11 @@ class MeshCore:
|
|||
"""
|
||||
Interface to a MeshCore device
|
||||
"""
|
||||
def __init__(self, cx, debug=False):
|
||||
def __init__(self, cx, debug=False, default_timeout=None):
|
||||
self.cx = cx
|
||||
self.dispatcher = EventDispatcher()
|
||||
self._reader = MessageReader(self.dispatcher)
|
||||
self.commands = CommandHandler()
|
||||
self.commands = CommandHandler(default_timeout=default_timeout)
|
||||
|
||||
# Set up logger
|
||||
if debug:
|
||||
|
|
@ -58,19 +58,19 @@ class MeshCore:
|
|||
cx.set_reader(self._reader)
|
||||
|
||||
@classmethod
|
||||
async def create_tcp(cls, host: str, port: int, debug: bool = False) -> 'MeshCore':
|
||||
async def create_tcp(cls, host: str, port: int, debug: bool = False, default_timeout=None) -> 'MeshCore':
|
||||
"""Create and connect a MeshCore instance using TCP connection"""
|
||||
from .tcp_cx import TCPConnection
|
||||
|
||||
connection = TCPConnection(host, port)
|
||||
await connection.connect()
|
||||
|
||||
mc = cls(connection, debug=debug)
|
||||
mc = cls(connection, debug=debug, default_timeout=default_timeout)
|
||||
await mc.connect()
|
||||
return mc
|
||||
|
||||
@classmethod
|
||||
async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False) -> 'MeshCore':
|
||||
async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False, default_timeout=None) -> 'MeshCore':
|
||||
"""Create and connect a MeshCore instance using serial connection"""
|
||||
from .serial_cx import SerialConnection
|
||||
import asyncio
|
||||
|
|
@ -79,12 +79,12 @@ class MeshCore:
|
|||
await connection.connect()
|
||||
await asyncio.sleep(0.1) # Time for transport to establish
|
||||
|
||||
mc = cls(connection, debug=debug)
|
||||
mc = cls(connection, debug=debug, default_timeout=default_timeout)
|
||||
await mc.connect()
|
||||
return mc
|
||||
|
||||
@classmethod
|
||||
async def create_ble(cls, address: Optional[str] = None, debug: bool = False) -> 'MeshCore':
|
||||
async def create_ble(cls, address: Optional[str] = None, debug: bool = False, default_timeout=None) -> 'MeshCore':
|
||||
"""Create and connect a MeshCore instance using BLE connection
|
||||
|
||||
If address is None, it will scan for and connect to the first available MeshCore device.
|
||||
|
|
@ -96,7 +96,7 @@ class MeshCore:
|
|||
if result is None:
|
||||
raise ConnectionError("Failed to connect to BLE device")
|
||||
|
||||
mc = cls(connection, debug=debug)
|
||||
mc = cls(connection, debug=debug, default_timeout=default_timeout)
|
||||
await mc.connect()
|
||||
return mc
|
||||
|
||||
|
|
@ -142,11 +142,15 @@ class MeshCore:
|
|||
|
||||
Args:
|
||||
event_type: Type of event to wait for, from EventType enum
|
||||
timeout: Maximum time to wait in seconds, or None for no timeout
|
||||
timeout: Maximum time to wait in seconds, or None to use default_timeout
|
||||
|
||||
Returns:
|
||||
Event object or None if timeout
|
||||
"""
|
||||
# Use the provided timeout or fall back to default_timeout
|
||||
if timeout is None:
|
||||
timeout = self.default_timeout
|
||||
|
||||
return await self.dispatcher.wait_for_event(event_type, timeout)
|
||||
|
||||
def _setup_data_tracking(self):
|
||||
|
|
@ -181,6 +185,16 @@ class MeshCore:
|
|||
"""Get the current device time"""
|
||||
return self._time
|
||||
|
||||
@property
|
||||
def default_timeout(self):
|
||||
"""Get the default timeout for commands"""
|
||||
return self.commands.default_timeout
|
||||
|
||||
@default_timeout.setter
|
||||
def default_timeout(self, value):
|
||||
"""Set the default timeout for commands"""
|
||||
self.commands.default_timeout = value
|
||||
|
||||
def get_contact_by_name(self, name):
|
||||
"""
|
||||
Find a contact by its name (adv_name field)
|
||||
|
|
@ -275,7 +289,7 @@ class MeshCore:
|
|||
if hasattr(self, '_auto_fetch_task') and self._auto_fetch_task and not self._auto_fetch_task.done():
|
||||
self._auto_fetch_task.cancel()
|
||||
try:
|
||||
await self._auto_fetch_task
|
||||
await self._auto_fetch_task # type: ignore
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._auto_fetch_task = None
|
||||
|
|
|
|||
|
|
@ -1,249 +0,0 @@
|
|||
import asyncio
|
||||
from typing import Dict, Any, Optional, Callable
|
||||
|
||||
from .events import EventDispatcher, MessageType, Event
|
||||
from .reader import MessageReader
|
||||
from .commands import CommandHandler, deprecated
|
||||
|
||||
|
||||
class MeshCore:
|
||||
def __init__(self, cx):
|
||||
self.cx = cx
|
||||
self.dispatcher = EventDispatcher()
|
||||
self._reader = MessageReader(self.dispatcher)
|
||||
self.commands = CommandHandler()
|
||||
|
||||
# Set up connections
|
||||
self.commands.set_connection(cx)
|
||||
|
||||
# Initialize state
|
||||
self.contacts = {}
|
||||
self.self_info = {}
|
||||
self.time = 0
|
||||
|
||||
# Set the message handler in the connection
|
||||
cx.set_mc(self)
|
||||
|
||||
async def connect(self):
|
||||
# Start the event dispatcher
|
||||
await self.dispatcher.start()
|
||||
|
||||
# Start the command handler
|
||||
await self.commands.start()
|
||||
|
||||
# Send the initial app start
|
||||
return await self.commands.send_appstart()
|
||||
|
||||
async def disconnect(self):
|
||||
# Stop the event dispatcher
|
||||
await self.dispatcher.stop()
|
||||
|
||||
# Stop the command handler
|
||||
await self.commands.stop()
|
||||
|
||||
# Internal method - called by the connection
|
||||
def handle_rx(self, data: bytearray):
|
||||
asyncio.create_task(self._reader.handle_rx(data))
|
||||
|
||||
# Expose subscribe/wait capabilities from the event system
|
||||
def subscribe(self, message_type, callback):
|
||||
return self.dispatcher.subscribe(message_type, callback)
|
||||
|
||||
async def wait_for_event(self, message_type, timeout=None):
|
||||
return await self.dispatcher.wait_for_event(message_type, timeout)
|
||||
|
||||
# Legacy method implementations that delegate to the command handler
|
||||
# using the deprecated decorator from commands.py
|
||||
|
||||
@deprecated
|
||||
async def send(self, data, timeout=5):
|
||||
return await self.commands.send(data, timeout)
|
||||
|
||||
@deprecated
|
||||
async def send_only(self, data):
|
||||
await self.commands.send_only(data)
|
||||
|
||||
@deprecated
|
||||
async def send_appstart(self):
|
||||
return await self.commands.send_appstart()
|
||||
|
||||
@deprecated
|
||||
async def send_device_query(self):
|
||||
return await self.commands.send_device_query()
|
||||
|
||||
@deprecated
|
||||
async def send_advert(self, flood=False):
|
||||
return await self.commands.send_advert(flood)
|
||||
|
||||
@deprecated
|
||||
async def set_name(self, name):
|
||||
return await self.commands.set_name(name)
|
||||
|
||||
@deprecated
|
||||
async def set_coords(self, lat, lon):
|
||||
return await self.commands.set_coords(lat, lon)
|
||||
|
||||
@deprecated
|
||||
async def reboot(self):
|
||||
return await self.commands.reboot()
|
||||
|
||||
@deprecated
|
||||
async def get_bat(self):
|
||||
return await self.commands.get_bat()
|
||||
|
||||
@deprecated
|
||||
async def get_time(self):
|
||||
time_result = await self.commands.get_time()
|
||||
if isinstance(time_result, int):
|
||||
self.time = time_result
|
||||
return self.time
|
||||
|
||||
@deprecated
|
||||
async def set_time(self, val):
|
||||
return await self.commands.set_time(val)
|
||||
|
||||
@deprecated
|
||||
async def set_tx_power(self, val):
|
||||
return await self.commands.set_tx_power(val)
|
||||
|
||||
@deprecated
|
||||
async def set_radio(self, freq, bw, sf, cr):
|
||||
return await self.commands.set_radio(freq, bw, sf, cr)
|
||||
|
||||
@deprecated
|
||||
async def set_tuning(self, rx_dly, af):
|
||||
return await self.commands.set_tuning(rx_dly, af)
|
||||
|
||||
@deprecated
|
||||
async def set_devicepin(self, pin):
|
||||
return await self.commands.set_devicepin(pin)
|
||||
|
||||
@deprecated
|
||||
async def get_contacts(self):
|
||||
await self.commands.get_contacts()
|
||||
contact_end = await self.dispatcher.wait_for_event(MessageType.CONTACT_END)
|
||||
if contact_end:
|
||||
self.contacts = contact_end.payload
|
||||
return self.contacts
|
||||
|
||||
@deprecated
|
||||
async def ensure_contacts(self):
|
||||
if not self.contacts:
|
||||
await self.get_contacts()
|
||||
|
||||
@deprecated
|
||||
async def reset_path(self, key):
|
||||
return await self.commands.reset_path(key)
|
||||
|
||||
@deprecated
|
||||
async def share_contact(self, key):
|
||||
return await self.commands.share_contact(key)
|
||||
|
||||
@deprecated
|
||||
async def export_contact(self, key=b""):
|
||||
return await self.commands.export_contact(key)
|
||||
|
||||
@deprecated
|
||||
async def remove_contact(self, key):
|
||||
return await self.commands.remove_contact(key)
|
||||
|
||||
@deprecated
|
||||
async def set_out_path(self, contact, path):
|
||||
contact["out_path"] = path
|
||||
contact["out_path_len"] = -1
|
||||
contact["out_path_len"] = int(len(path) / 2)
|
||||
|
||||
@deprecated
|
||||
async def update_contact(self, contact):
|
||||
out_path_hex = contact["out_path"]
|
||||
out_path_hex = out_path_hex + (128-len(out_path_hex)) * "0"
|
||||
adv_name_hex = contact["adv_name"].encode().hex()
|
||||
adv_name_hex = adv_name_hex + (64-len(adv_name_hex)) * "0"
|
||||
data = b"\x09" \
|
||||
+ bytes.fromhex(contact["public_key"])\
|
||||
+ contact["type"].to_bytes(1)\
|
||||
+ contact["flags"].to_bytes(1)\
|
||||
+ contact["out_path_len"].to_bytes(1, 'little', signed=True)\
|
||||
+ bytes.fromhex(out_path_hex)\
|
||||
+ bytes.fromhex(adv_name_hex)\
|
||||
+ contact["last_advert"].to_bytes(4, 'little')\
|
||||
+ int(contact["adv_lat"]*1e6).to_bytes(4, 'little', signed=True)\
|
||||
+ int(contact["adv_lon"]*1e6).to_bytes(4, 'little', signed=True)
|
||||
return await self.send(data)
|
||||
|
||||
@deprecated
|
||||
async def send_login(self, dst, pwd):
|
||||
await self.commands.send_login(dst, pwd)
|
||||
login_event = await self.dispatcher.wait_for_event(MessageType.LOGIN_SUCCESS, 0.1)
|
||||
if login_event:
|
||||
return True
|
||||
return await self.commands.send_login(dst, pwd)
|
||||
|
||||
@deprecated
|
||||
async def wait_login(self, timeout=5):
|
||||
login_event = await self.dispatcher.wait_for_event(MessageType.LOGIN_SUCCESS, timeout)
|
||||
if login_event:
|
||||
return True
|
||||
login_failed = await self.dispatcher.wait_for_event(MessageType.LOGIN_FAILED, 0)
|
||||
if login_failed:
|
||||
return False
|
||||
return False
|
||||
|
||||
@deprecated
|
||||
async def send_statusreq(self, dst):
|
||||
await self.commands.send_statusreq(dst)
|
||||
|
||||
@deprecated
|
||||
async def wait_status(self, timeout=5):
|
||||
status_event = await self.dispatcher.wait_for_event(MessageType.STATUS_RESPONSE, timeout)
|
||||
if status_event:
|
||||
return status_event.payload
|
||||
return False
|
||||
|
||||
@deprecated
|
||||
async def send_cmd(self, dst, cmd):
|
||||
timestamp = await self.get_time()
|
||||
return await self.commands.send_cmd(dst, cmd, timestamp.to_bytes(4, 'little'))
|
||||
|
||||
@deprecated
|
||||
async def send_msg(self, dst, msg):
|
||||
timestamp = await self.get_time()
|
||||
result = await self.commands.send_msg(dst, msg, timestamp.to_bytes(4, 'little'))
|
||||
return result
|
||||
|
||||
@deprecated
|
||||
async def send_chan_msg(self, chan, msg):
|
||||
timestamp = await self.get_time()
|
||||
return await self.commands.send_chan_msg(chan, msg, timestamp.to_bytes(4, 'little'))
|
||||
|
||||
@deprecated
|
||||
async def get_msg(self):
|
||||
await self.commands.get_msg()
|
||||
|
||||
# Wait for any message type that could be received
|
||||
message_types = [
|
||||
MessageType.CONTACT_MSG_RECV,
|
||||
MessageType.CHANNEL_MSG_RECV,
|
||||
MessageType.NO_MORE_MSGS
|
||||
]
|
||||
|
||||
for msg_type in message_types:
|
||||
event = await self.dispatcher.wait_for_event(msg_type, 0)
|
||||
if event:
|
||||
return event.payload
|
||||
|
||||
return False
|
||||
|
||||
@deprecated
|
||||
async def wait_msg(self, timeout=-1):
|
||||
msg_event = await self.dispatcher.wait_for_event(MessageType.MESSAGES_WAITING, timeout)
|
||||
return msg_event is not None
|
||||
|
||||
@deprecated
|
||||
async def wait_ack(self, timeout=6):
|
||||
ack_event = await self.dispatcher.wait_for_event(MessageType.ACK, timeout)
|
||||
return ack_event is not None
|
||||
|
||||
@deprecated
|
||||
async def send_cli(self, cmd):
|
||||
return await self.commands.send_cli(cmd)
|
||||
|
|
@ -15,6 +15,7 @@ class SerialConnection:
|
|||
self.baudrate = baudrate
|
||||
self.frame_started = False
|
||||
self.frame_size = 0
|
||||
self.transport = None
|
||||
self.header = b""
|
||||
self.inframe = b""
|
||||
|
||||
|
|
@ -25,7 +26,8 @@ class SerialConnection:
|
|||
def connection_made(self, transport):
|
||||
self.cx.transport = transport
|
||||
logger.debug('port opened')
|
||||
transport.serial.rts = False # You can manipulate Serial object via transport
|
||||
if isinstance(transport, serial_asyncio.SerialTransport) and transport.serial:
|
||||
transport.serial.rts = False # You can manipulate Serial object via transport
|
||||
|
||||
def data_received(self, data):
|
||||
self.cx.handle_rx(data)
|
||||
|
|
@ -79,6 +81,9 @@ class SerialConnection:
|
|||
self.handle_rx(data[self.frame_size-framelen:])
|
||||
|
||||
async def send(self, data):
|
||||
if not self.transport:
|
||||
logger.error("Transport not connected, cannot send data")
|
||||
return
|
||||
size = len(data)
|
||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||
logger.debug(f"sending pkt : {pkt}")
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class TCPConnection:
|
|||
self.header = b""
|
||||
self.inframe = b""
|
||||
|
||||
class MCClientProtocol:
|
||||
class MCClientProtocol(asyncio.Protocol):
|
||||
def __init__(self, cx):
|
||||
self.cx = cx
|
||||
|
||||
|
|
@ -76,6 +76,9 @@ class TCPConnection:
|
|||
self.handle_rx(data[self.frame_size-framelen:])
|
||||
|
||||
async def send(self, data):
|
||||
if not self.transport:
|
||||
logger.error("Transport not connected, cannot send data")
|
||||
return
|
||||
size = len(data)
|
||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||
logger.debug(f"sending pkt : {pkt}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue