Refactor to event system

This commit is contained in:
Alex Wolden 2025-04-08 22:56:16 -07:00
parent 8f0ecd7d75
commit a5f1ec5c26
7 changed files with 66 additions and 271 deletions

View file

@ -62,6 +62,9 @@ class BLEConnection:
await self.client.start_notify(UART_TX_CHAR_UUID, self.handle_rx)
nus = self.client.services.get_service(UART_SERVICE_UUID)
if nus is None:
logger.error("Could not find UART service")
return None
self.rx_char = nus.get_characteristic(UART_RX_CHAR_UUID)
logger.info("BLE Connection started")
@ -82,4 +85,10 @@ class BLEConnection:
asyncio.create_task(self.reader.handle_rx(data))
async def send(self, data):
if not self.client:
logger.error("Client is not connected")
return False
if not self.rx_char:
logger.error("RX characteristic not found")
return False
await self.client.write_gatt_char(self.rx_char, bytes(data), response=False)

View file

@ -26,10 +26,13 @@ def deprecated(func):
class CommandHandler:
def __init__(self):
DEFAULT_TIMEOUT = 5.0
def __init__(self, default_timeout=None):
self._sender_func = None
self._reader = None
self.dispatcher = None
self.default_timeout = default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
def set_connection(self, connection):
async def sender(data):
@ -42,10 +45,13 @@ class CommandHandler:
def set_dispatcher(self, dispatcher):
self.dispatcher = dispatcher
async def send(self, data, expected_events=None, timeout=5.0):
async def send(self, data, expected_events=None, timeout=None):
if not self.dispatcher:
raise RuntimeError("Dispatcher not set, cannot send commands")
# Use the provided timeout or fall back to default_timeout
timeout = timeout if timeout is not None else self.default_timeout
if self._sender_func:
logger.debug(f"Sending raw data: {data.hex() if isinstance(data, bytes) else data}")
await self._sender_func(data)
@ -163,15 +169,20 @@ class CommandHandler:
data = b"\x0f" + key
return await self.send(data, [EventType.OK, EventType.ERROR])
async def get_msg(self):
async def get_msg(self, timeout=1):
logger.debug("Requesting pending messages")
return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR], 1)
return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR], timeout)
async def send_login(self, dst, pwd):
logger.debug(f"Sending login request to: {dst.hex() if isinstance(dst, bytes) else dst}")
data = b"\x1a" + dst + pwd.encode("ascii")
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def send_logout(self, dst):
self.login_resp = asyncio.Future()
data = b"\x1d" + dst
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def send_statusreq(self, dst):
logger.debug(f"Sending status request to: {dst.hex() if isinstance(dst, bytes) else dst}")
data = b"\x1b" + dst

View file

@ -40,7 +40,7 @@ class EventType(Enum):
class Event:
type: EventType
payload: Any
attributes: Dict[str, Any] = None
attributes: Dict[str, Any] = {}
def __post_init__(self):
if self.attributes is None:
@ -64,7 +64,7 @@ class EventDispatcher:
self.running = False
self._task = None
def subscribe(self, event_type: Union[EventType, None], callback: Callable[[Event], None]) -> Subscription:
def subscribe(self, event_type: Union[EventType, None], callback: Callable[[Event], Union[None, asyncio.Future]]) -> Subscription:
subscription = Subscription(self, event_type, callback)
self.subscriptions.append(subscription)
return subscription
@ -83,7 +83,9 @@ class EventDispatcher:
for subscription in self.subscriptions.copy():
if subscription.event_type is None or subscription.event_type == event.type:
try:
await subscription.callback(event)
result = subscription.callback(event)
if asyncio.iscoroutine(result):
await result
except Exception as e:
print(f"Error in event handler: {e}")
@ -106,13 +108,13 @@ class EventDispatcher:
pass
self._task = None
async def wait_for_event(self, event_type: EventType, timeout: float = None) -> Optional[Event]:
async def wait_for_event(self, event_type: EventType, timeout: float | None = None) -> Optional[Event]:
future = asyncio.Future()
async def event_handler(event: Event):
def event_handler(event: Event):
if not future.done():
future.set_result(event)
subscription = self.subscribe(event_type, event_handler)
try:

View file

@ -28,11 +28,11 @@ class MeshCore:
"""
Interface to a MeshCore device
"""
def __init__(self, cx, debug=False):
def __init__(self, cx, debug=False, default_timeout=None):
self.cx = cx
self.dispatcher = EventDispatcher()
self._reader = MessageReader(self.dispatcher)
self.commands = CommandHandler()
self.commands = CommandHandler(default_timeout=default_timeout)
# Set up logger
if debug:
@ -58,19 +58,19 @@ class MeshCore:
cx.set_reader(self._reader)
@classmethod
async def create_tcp(cls, host: str, port: int, debug: bool = False) -> 'MeshCore':
async def create_tcp(cls, host: str, port: int, debug: bool = False, default_timeout=None) -> 'MeshCore':
"""Create and connect a MeshCore instance using TCP connection"""
from .tcp_cx import TCPConnection
connection = TCPConnection(host, port)
await connection.connect()
mc = cls(connection, debug=debug)
mc = cls(connection, debug=debug, default_timeout=default_timeout)
await mc.connect()
return mc
@classmethod
async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False) -> 'MeshCore':
async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False, default_timeout=None) -> 'MeshCore':
"""Create and connect a MeshCore instance using serial connection"""
from .serial_cx import SerialConnection
import asyncio
@ -79,12 +79,12 @@ class MeshCore:
await connection.connect()
await asyncio.sleep(0.1) # Time for transport to establish
mc = cls(connection, debug=debug)
mc = cls(connection, debug=debug, default_timeout=default_timeout)
await mc.connect()
return mc
@classmethod
async def create_ble(cls, address: Optional[str] = None, debug: bool = False) -> 'MeshCore':
async def create_ble(cls, address: Optional[str] = None, debug: bool = False, default_timeout=None) -> 'MeshCore':
"""Create and connect a MeshCore instance using BLE connection
If address is None, it will scan for and connect to the first available MeshCore device.
@ -96,7 +96,7 @@ class MeshCore:
if result is None:
raise ConnectionError("Failed to connect to BLE device")
mc = cls(connection, debug=debug)
mc = cls(connection, debug=debug, default_timeout=default_timeout)
await mc.connect()
return mc
@ -142,11 +142,15 @@ class MeshCore:
Args:
event_type: Type of event to wait for, from EventType enum
timeout: Maximum time to wait in seconds, or None for no timeout
timeout: Maximum time to wait in seconds, or None to use default_timeout
Returns:
Event object or None if timeout
"""
# Use the provided timeout or fall back to default_timeout
if timeout is None:
timeout = self.default_timeout
return await self.dispatcher.wait_for_event(event_type, timeout)
def _setup_data_tracking(self):
@ -181,6 +185,16 @@ class MeshCore:
"""Get the current device time"""
return self._time
@property
def default_timeout(self):
"""Get the default timeout for commands"""
return self.commands.default_timeout
@default_timeout.setter
def default_timeout(self, value):
"""Set the default timeout for commands"""
self.commands.default_timeout = value
def get_contact_by_name(self, name):
"""
Find a contact by its name (adv_name field)
@ -275,7 +289,7 @@ class MeshCore:
if hasattr(self, '_auto_fetch_task') and self._auto_fetch_task and not self._auto_fetch_task.done():
self._auto_fetch_task.cancel()
try:
await self._auto_fetch_task
await self._auto_fetch_task # type: ignore
except asyncio.CancelledError:
pass
self._auto_fetch_task = None

View file

@ -1,249 +0,0 @@
import asyncio
from typing import Dict, Any, Optional, Callable
from .events import EventDispatcher, MessageType, Event
from .reader import MessageReader
from .commands import CommandHandler, deprecated
class MeshCore:
def __init__(self, cx):
self.cx = cx
self.dispatcher = EventDispatcher()
self._reader = MessageReader(self.dispatcher)
self.commands = CommandHandler()
# Set up connections
self.commands.set_connection(cx)
# Initialize state
self.contacts = {}
self.self_info = {}
self.time = 0
# Set the message handler in the connection
cx.set_mc(self)
async def connect(self):
# Start the event dispatcher
await self.dispatcher.start()
# Start the command handler
await self.commands.start()
# Send the initial app start
return await self.commands.send_appstart()
async def disconnect(self):
# Stop the event dispatcher
await self.dispatcher.stop()
# Stop the command handler
await self.commands.stop()
# Internal method - called by the connection
def handle_rx(self, data: bytearray):
asyncio.create_task(self._reader.handle_rx(data))
# Expose subscribe/wait capabilities from the event system
def subscribe(self, message_type, callback):
return self.dispatcher.subscribe(message_type, callback)
async def wait_for_event(self, message_type, timeout=None):
return await self.dispatcher.wait_for_event(message_type, timeout)
# Legacy method implementations that delegate to the command handler
# using the deprecated decorator from commands.py
@deprecated
async def send(self, data, timeout=5):
return await self.commands.send(data, timeout)
@deprecated
async def send_only(self, data):
await self.commands.send_only(data)
@deprecated
async def send_appstart(self):
return await self.commands.send_appstart()
@deprecated
async def send_device_query(self):
return await self.commands.send_device_query()
@deprecated
async def send_advert(self, flood=False):
return await self.commands.send_advert(flood)
@deprecated
async def set_name(self, name):
return await self.commands.set_name(name)
@deprecated
async def set_coords(self, lat, lon):
return await self.commands.set_coords(lat, lon)
@deprecated
async def reboot(self):
return await self.commands.reboot()
@deprecated
async def get_bat(self):
return await self.commands.get_bat()
@deprecated
async def get_time(self):
time_result = await self.commands.get_time()
if isinstance(time_result, int):
self.time = time_result
return self.time
@deprecated
async def set_time(self, val):
return await self.commands.set_time(val)
@deprecated
async def set_tx_power(self, val):
return await self.commands.set_tx_power(val)
@deprecated
async def set_radio(self, freq, bw, sf, cr):
return await self.commands.set_radio(freq, bw, sf, cr)
@deprecated
async def set_tuning(self, rx_dly, af):
return await self.commands.set_tuning(rx_dly, af)
@deprecated
async def set_devicepin(self, pin):
return await self.commands.set_devicepin(pin)
@deprecated
async def get_contacts(self):
await self.commands.get_contacts()
contact_end = await self.dispatcher.wait_for_event(MessageType.CONTACT_END)
if contact_end:
self.contacts = contact_end.payload
return self.contacts
@deprecated
async def ensure_contacts(self):
if not self.contacts:
await self.get_contacts()
@deprecated
async def reset_path(self, key):
return await self.commands.reset_path(key)
@deprecated
async def share_contact(self, key):
return await self.commands.share_contact(key)
@deprecated
async def export_contact(self, key=b""):
return await self.commands.export_contact(key)
@deprecated
async def remove_contact(self, key):
return await self.commands.remove_contact(key)
@deprecated
async def set_out_path(self, contact, path):
contact["out_path"] = path
contact["out_path_len"] = -1
contact["out_path_len"] = int(len(path) / 2)
@deprecated
async def update_contact(self, contact):
out_path_hex = contact["out_path"]
out_path_hex = out_path_hex + (128-len(out_path_hex)) * "0"
adv_name_hex = contact["adv_name"].encode().hex()
adv_name_hex = adv_name_hex + (64-len(adv_name_hex)) * "0"
data = b"\x09" \
+ bytes.fromhex(contact["public_key"])\
+ contact["type"].to_bytes(1)\
+ contact["flags"].to_bytes(1)\
+ contact["out_path_len"].to_bytes(1, 'little', signed=True)\
+ bytes.fromhex(out_path_hex)\
+ bytes.fromhex(adv_name_hex)\
+ contact["last_advert"].to_bytes(4, 'little')\
+ int(contact["adv_lat"]*1e6).to_bytes(4, 'little', signed=True)\
+ int(contact["adv_lon"]*1e6).to_bytes(4, 'little', signed=True)
return await self.send(data)
@deprecated
async def send_login(self, dst, pwd):
await self.commands.send_login(dst, pwd)
login_event = await self.dispatcher.wait_for_event(MessageType.LOGIN_SUCCESS, 0.1)
if login_event:
return True
return await self.commands.send_login(dst, pwd)
@deprecated
async def wait_login(self, timeout=5):
login_event = await self.dispatcher.wait_for_event(MessageType.LOGIN_SUCCESS, timeout)
if login_event:
return True
login_failed = await self.dispatcher.wait_for_event(MessageType.LOGIN_FAILED, 0)
if login_failed:
return False
return False
@deprecated
async def send_statusreq(self, dst):
await self.commands.send_statusreq(dst)
@deprecated
async def wait_status(self, timeout=5):
status_event = await self.dispatcher.wait_for_event(MessageType.STATUS_RESPONSE, timeout)
if status_event:
return status_event.payload
return False
@deprecated
async def send_cmd(self, dst, cmd):
timestamp = await self.get_time()
return await self.commands.send_cmd(dst, cmd, timestamp.to_bytes(4, 'little'))
@deprecated
async def send_msg(self, dst, msg):
timestamp = await self.get_time()
result = await self.commands.send_msg(dst, msg, timestamp.to_bytes(4, 'little'))
return result
@deprecated
async def send_chan_msg(self, chan, msg):
timestamp = await self.get_time()
return await self.commands.send_chan_msg(chan, msg, timestamp.to_bytes(4, 'little'))
@deprecated
async def get_msg(self):
await self.commands.get_msg()
# Wait for any message type that could be received
message_types = [
MessageType.CONTACT_MSG_RECV,
MessageType.CHANNEL_MSG_RECV,
MessageType.NO_MORE_MSGS
]
for msg_type in message_types:
event = await self.dispatcher.wait_for_event(msg_type, 0)
if event:
return event.payload
return False
@deprecated
async def wait_msg(self, timeout=-1):
msg_event = await self.dispatcher.wait_for_event(MessageType.MESSAGES_WAITING, timeout)
return msg_event is not None
@deprecated
async def wait_ack(self, timeout=6):
ack_event = await self.dispatcher.wait_for_event(MessageType.ACK, timeout)
return ack_event is not None
@deprecated
async def send_cli(self, cmd):
return await self.commands.send_cli(cmd)

View file

@ -15,6 +15,7 @@ class SerialConnection:
self.baudrate = baudrate
self.frame_started = False
self.frame_size = 0
self.transport = None
self.header = b""
self.inframe = b""
@ -25,7 +26,8 @@ class SerialConnection:
def connection_made(self, transport):
self.cx.transport = transport
logger.debug('port opened')
transport.serial.rts = False # You can manipulate Serial object via transport
if isinstance(transport, serial_asyncio.SerialTransport) and transport.serial:
transport.serial.rts = False # You can manipulate Serial object via transport
def data_received(self, data):
self.cx.handle_rx(data)
@ -79,6 +81,9 @@ class SerialConnection:
self.handle_rx(data[self.frame_size-framelen:])
async def send(self, data):
if not self.transport:
logger.error("Transport not connected, cannot send data")
return
size = len(data)
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
logger.debug(f"sending pkt : {pkt}")

View file

@ -18,7 +18,7 @@ class TCPConnection:
self.header = b""
self.inframe = b""
class MCClientProtocol:
class MCClientProtocol(asyncio.Protocol):
def __init__(self, cx):
self.cx = cx
@ -76,6 +76,9 @@ class TCPConnection:
self.handle_rx(data[self.frame_size-framelen:])
async def send(self, data):
if not self.transport:
logger.error("Transport not connected, cannot send data")
return
size = len(data)
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
logger.debug(f"sending pkt : {pkt}")