feat: add ML-based adaptive timeout prediction using LinearRegressor

Train a linear regression model on actual message delivery times to
predict tighter timeouts, replacing worst-case physics estimates.
Features: path length, message bytes, seconds since last RX, flood mode.
Global model with per-contact blending after 10+ observations per contact.
Falls back to existing physics formula when model has insufficient data.
This commit is contained in:
zjs81 2026-03-14 16:56:11 -07:00
parent 8b280b37be
commit 2ee2358ecc
9 changed files with 683 additions and 20 deletions

View file

@ -19,6 +19,7 @@ import '../services/message_retry_service.dart';
import '../services/path_history_service.dart';
import '../services/app_settings_service.dart';
import '../services/background_service.dart';
import '../services/timeout_prediction_service.dart';
import '../services/notification_service.dart';
import 'meshcore_connector_usb.dart';
import 'meshcore_connector_tcp.dart';
@ -166,6 +167,8 @@ class MeshCoreConnector extends ChangeNotifier {
bool _isLoadingContacts = false;
bool _isLoadingChannels = false;
bool _hasLoadedChannels = false;
TimeoutPredictionService? _timeoutPredictionService;
DateTime _lastRxTime = DateTime.now();
bool _batteryRequested = false;
bool _awaitingSelfInfo = false;
bool _hasReceivedDeviceInfo = false;
@ -668,6 +671,7 @@ class MeshCoreConnector extends ChangeNotifier {
BleDebugLogService? bleDebugLogService,
AppDebugLogService? appDebugLogService,
BackgroundService? backgroundService,
TimeoutPredictionService? timeoutPredictionService,
}) {
_retryService = retryService;
_pathHistoryService = pathHistoryService;
@ -675,6 +679,7 @@ class MeshCoreConnector extends ChangeNotifier {
_bleDebugLogService = bleDebugLogService;
_appDebugLogService = appDebugLogService;
_backgroundService = backgroundService;
_timeoutPredictionService = timeoutPredictionService;
_usbManager.setDebugLogService(_appDebugLogService);
_tcpConnector.setDebugLogService(_appDebugLogService);
@ -689,13 +694,23 @@ class MeshCoreConnector extends ChangeNotifier {
updateMessageCallback: _updateMessage,
clearContactPathCallback: clearContactPath,
setContactPathCallback: setContactPath,
calculateTimeoutCallback: (pathLength, messageBytes) =>
calculateTimeout(pathLength: pathLength, messageBytes: messageBytes),
calculateTimeoutCallback: (pathLength, messageBytes, {String? contactKey}) =>
calculateTimeout(pathLength: pathLength, messageBytes: messageBytes, contactKey: contactKey),
getSelfPublicKeyCallback: () => _selfPublicKey,
prepareContactOutboundTextCallback: prepareContactOutboundText,
appSettingsService: appSettingsService,
debugLogService: _appDebugLogService,
recordPathResultCallback: _recordPathResult,
onDeliveryObservedCallback: (contactKey, pathLength, messageBytes, tripTimeMs) {
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
_timeoutPredictionService?.recordObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
tripTimeMs: tripTimeMs,
secondsSinceLastRx: secSinceRx,
);
},
);
}
@ -2498,6 +2513,7 @@ class MeshCoreConnector extends ChangeNotifier {
void _handleFrame(List<int> data) {
if (data.isEmpty) return;
_lastRxTime = DateTime.now();
final frame = Uint8List.fromList(data);
_receivedFramesController.add(frame);
@ -2876,7 +2892,21 @@ class MeshCoreConnector extends ChangeNotifier {
/// Calculate timeout for a message based on radio settings and path length
/// Returns timeout in milliseconds, considering number of hops
int calculateTimeout({required int pathLength, int messageBytes = 100}) {
int calculateTimeout({
required int pathLength,
int messageBytes = 100,
String? contactKey,
}) {
// Try ML-based prediction first
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
final mlTimeout = _timeoutPredictionService?.predictTimeout(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
secondsSinceLastRx: secSinceRx,
);
if (mlTimeout != null) return mlTimeout;
// If we have radio settings, use them for accurate calculation
if (_currentFreqHz != null &&
_currentBwHz != null &&

View file

@ -19,6 +19,7 @@ import 'services/app_debug_log_service.dart';
import 'services/background_service.dart';
import 'services/map_tile_cache_service.dart';
import 'services/chat_text_scale_service.dart';
import 'services/timeout_prediction_service.dart';
import 'storage/prefs_manager.dart';
import 'utils/app_logger.dart';
@ -39,6 +40,7 @@ void main() async {
final backgroundService = BackgroundService();
final mapTileCacheService = MapTileCacheService();
final chatTextScaleService = ChatTextScaleService();
final timeoutPredictionService = TimeoutPredictionService(storage);
// Load settings
await appSettingsService.loadSettings();
@ -56,6 +58,7 @@ void main() async {
_registerThirdPartyLicenses();
await chatTextScaleService.initialize();
await timeoutPredictionService.initialize();
// Wire up connector with services
connector.initialize(
@ -65,6 +68,7 @@ void main() async {
bleDebugLogService: bleDebugLogService,
appDebugLogService: appDebugLogService,
backgroundService: backgroundService,
timeoutPredictionService: timeoutPredictionService,
);
await connector.loadContactCache();
@ -86,6 +90,7 @@ void main() async {
appDebugLogService: appDebugLogService,
mapTileCacheService: mapTileCacheService,
chatTextScaleService: chatTextScaleService,
timeoutPredictionService: timeoutPredictionService,
),
);
}
@ -121,6 +126,7 @@ class MeshCoreApp extends StatelessWidget {
final AppDebugLogService appDebugLogService;
final MapTileCacheService mapTileCacheService;
final ChatTextScaleService chatTextScaleService;
final TimeoutPredictionService timeoutPredictionService;
const MeshCoreApp({
super.key,
@ -133,6 +139,7 @@ class MeshCoreApp extends StatelessWidget {
required this.appDebugLogService,
required this.mapTileCacheService,
required this.chatTextScaleService,
required this.timeoutPredictionService,
});
@override
@ -148,6 +155,7 @@ class MeshCoreApp extends StatelessWidget {
ChangeNotifierProvider.value(value: chatTextScaleService),
Provider.value(value: storage),
Provider.value(value: mapTileCacheService),
ChangeNotifierProvider.value(value: timeoutPredictionService),
],
child: Consumer<AppSettingsService>(
builder: (context, settingsService, child) {

View file

@ -0,0 +1,43 @@
class DeliveryObservation {
final String contactKey;
final int pathLength;
final int messageBytes;
final int secondsSinceLastRx;
final bool isFlood;
final int deliveryMs;
final DateTime timestamp;
DeliveryObservation({
required this.contactKey,
required this.pathLength,
required this.messageBytes,
required this.secondsSinceLastRx,
required this.isFlood,
required this.deliveryMs,
required this.timestamp,
});
Map<String, dynamic> toJson() {
return {
'contact_key': contactKey,
'path_length': pathLength,
'message_bytes': messageBytes,
'seconds_since_last_rx': secondsSinceLastRx,
'is_flood': isFlood,
'delivery_ms': deliveryMs,
'timestamp': timestamp.toIso8601String(),
};
}
factory DeliveryObservation.fromJson(Map<String, dynamic> json) {
return DeliveryObservation(
contactKey: json['contact_key'] as String,
pathLength: json['path_length'] as int,
messageBytes: json['message_bytes'] as int,
secondsSinceLastRx: json['seconds_since_last_rx'] as int? ?? 0,
isFlood: json['is_flood'] as bool,
deliveryMs: json['delivery_ms'] as int,
timestamp: DateTime.parse(json['timestamp'] as String),
);
}
}

View file

@ -58,12 +58,13 @@ class MessageRetryService extends ChangeNotifier {
Function(Message)? _updateMessageCallback;
Function(Contact)? _clearContactPathCallback;
Function(Contact, Uint8List, int)? _setContactPathCallback;
Function(int, int)? _calculateTimeoutCallback;
Function(int, int, {String? contactKey})? _calculateTimeoutCallback;
Uint8List? Function()? _getSelfPublicKeyCallback;
String Function(Contact, String)? _prepareContactOutboundTextCallback;
AppSettingsService? _appSettingsService;
AppDebugLogService? _debugLogService;
Function(String, PathSelection, bool, int?)? _recordPathResultCallback;
Function(String, int, int, int)? _onDeliveryObservedCallback;
MessageRetryService();
@ -73,12 +74,14 @@ class MessageRetryService extends ChangeNotifier {
required Function(Message) updateMessageCallback,
Function(Contact)? clearContactPathCallback,
Function(Contact, Uint8List, int)? setContactPathCallback,
Function(int pathLength, int messageBytes)? calculateTimeoutCallback,
Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback,
Uint8List? Function()? getSelfPublicKeyCallback,
String Function(Contact, String)? prepareContactOutboundTextCallback,
AppSettingsService? appSettingsService,
AppDebugLogService? debugLogService,
Function(String, PathSelection, bool, int?)? recordPathResultCallback,
Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)?
onDeliveryObservedCallback,
}) {
_sendMessageCallback = sendMessageCallback;
_addMessageCallback = addMessageCallback;
@ -91,6 +94,7 @@ class MessageRetryService extends ChangeNotifier {
_appSettingsService = appSettingsService;
_debugLogService = debugLogService;
_recordPathResultCallback = recordPathResultCallback;
_onDeliveryObservedCallback = onDeliveryObservedCallback;
}
/// Compute expected ACK hash using same algorithm as firmware:
@ -423,25 +427,33 @@ class MessageRetryService extends ChangeNotifier {
);
}
// Use device-provided timeout, or calculate from radio settings if timeout is 0 or invalid
// Calculate timeout: prefer ML prediction, then device-provided, then physics fallback
int pathLengthValue;
if (selection != null) {
pathLengthValue = selection.useFlood ? -1 : selection.hopCount;
if (pathLengthValue < 0) pathLengthValue = contact.pathLength;
} else if (message.pathLength != null) {
pathLengthValue = message.pathLength!;
} else {
pathLengthValue = contact.pathLength;
}
int actualTimeout = timeoutMs;
if (timeoutMs <= 0 && _calculateTimeoutCallback != null) {
int pathLengthValue;
if (selection != null) {
pathLengthValue = selection.useFlood ? -1 : selection.hopCount;
if (pathLengthValue < 0) pathLengthValue = contact.pathLength;
} else if (message.pathLength != null) {
pathLengthValue = message.pathLength!;
} else {
pathLengthValue = contact.pathLength;
}
actualTimeout = _calculateTimeoutCallback!(
if (_calculateTimeoutCallback != null) {
final calculated = _calculateTimeoutCallback!(
pathLengthValue,
message.text.length,
contactKey: contact.publicKeyHex,
);
debugPrint(
'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue',
);
// calculateTimeout tries ML first, falls back to physics.
// Use calculated value if device didn't provide one, or if ML
// produced a tighter prediction than the device's estimate.
if (timeoutMs <= 0 || calculated < timeoutMs) {
actualTimeout = calculated;
debugPrint(
'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue',
);
}
}
final updatedMessage = message.copyWith(
@ -738,6 +750,14 @@ class MessageRetryService extends ChangeNotifier {
true,
tripTimeMs,
);
if (_onDeliveryObservedCallback != null && tripTimeMs > 0) {
_onDeliveryObservedCallback!(
contact.publicKeyHex,
message.pathLength ?? 0,
message.text.length,
tripTimeMs,
);
}
_onMessageResolved(matchedMessageId, contact.publicKeyHex);
}

View file

@ -1,4 +1,5 @@
import 'dart:convert';
import '../models/delivery_observation.dart';
import '../models/path_history.dart';
import '../storage/prefs_manager.dart';
@ -6,6 +7,8 @@ class StorageService {
static const String _pathHistoryPrefix = 'path_history_';
static const String _pendingMessagesKey = 'pending_messages';
static const String _repeaterPasswordsKey = 'repeater_passwords';
static const String _deliveryObservationsKey = 'delivery_observations';
static const String _timeoutModelKey = 'timeout_ml_model';
Future<void> savePathHistory(
String contactPubKeyHex,
@ -122,4 +125,51 @@ class StorageService {
final prefs = PrefsManager.instance;
await prefs.remove(_repeaterPasswordsKey);
}
Future<void> saveDeliveryObservations(
List<DeliveryObservation> observations,
) async {
final prefs = PrefsManager.instance;
final jsonStr = jsonEncode(observations.map((o) => o.toJson()).toList());
await prefs.setString(_deliveryObservationsKey, jsonStr);
}
Future<List<DeliveryObservation>> loadDeliveryObservations() async {
final prefs = PrefsManager.instance;
final jsonStr = prefs.getString(_deliveryObservationsKey);
if (jsonStr == null) return [];
try {
final list = jsonDecode(jsonStr) as List;
return list
.map(
(e) =>
DeliveryObservation.fromJson(e as Map<String, dynamic>),
)
.toList();
} catch (e) {
return [];
}
}
Future<void> clearDeliveryObservations() async {
final prefs = PrefsManager.instance;
await prefs.remove(_deliveryObservationsKey);
}
Future<void> saveTimeoutModel(String modelJson) async {
final prefs = PrefsManager.instance;
await prefs.setString(_timeoutModelKey, modelJson);
}
Future<String?> loadTimeoutModel() async {
final prefs = PrefsManager.instance;
return prefs.getString(_timeoutModelKey);
}
Future<void> clearTimeoutModel() async {
final prefs = PrefsManager.instance;
await prefs.remove(_timeoutModelKey);
}
}

View file

@ -0,0 +1,224 @@
import 'dart:convert';
import 'dart:math';
import 'package:flutter/foundation.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import '../models/delivery_observation.dart';
import 'storage_service.dart';
class _ContactStats {
int count = 0;
double _sum = 0;
double _sumSq = 0;
void add(double ms) {
count++;
_sum += ms;
_sumSq += ms * ms;
}
double get mean => _sum / count;
double get stdDev => sqrt((_sumSq / count) - (mean * mean));
}
class TimeoutPredictionService extends ChangeNotifier {
final StorageService? _storage;
static const int minObservations = 10;
static const int maxObservations = 100;
static const int _retrainInterval = 5;
static const double _safetyMargin = 1.5;
static const int _minTimeoutMs = 2000;
static const int _maxTimeoutMs = 120000;
static const int _minContactObservations = 10;
List<DeliveryObservation> _observations = [];
LinearRegressor? _model;
List<String> _activeFeatures = [];
int _observationsSinceLastTrain = 0;
final Map<String, _ContactStats> _contactStats = {};
TimeoutPredictionService(StorageService storage) : _storage = storage;
TimeoutPredictionService.noStorage() : _storage = null;
int get observationCount => _observations.length;
bool get hasModel => _model != null;
Future<void> initialize() async {
_observations = await _storage?.loadDeliveryObservations() ?? [];
_rebuildContactStats();
if (_observations.length >= minObservations) {
_trainModel();
}
debugPrint(
'TimeoutPrediction: initialized with ${_observations.length} observations, '
'model=${_model != null ? "ready" : "waiting for data"}',
);
}
void recordObservation({
required String contactKey,
required int pathLength,
required int messageBytes,
required int tripTimeMs,
int secondsSinceLastRx = 0,
}) {
final observation = DeliveryObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
secondsSinceLastRx: secondsSinceLastRx,
isFlood: pathLength < 0,
deliveryMs: tripTimeMs,
timestamp: DateTime.now(),
);
_observations.add(observation);
if (_observations.length > maxObservations) {
_observations.removeAt(0);
}
_contactStats.putIfAbsent(contactKey, () => _ContactStats());
_contactStats[contactKey]!.add(tripTimeMs.toDouble());
_observationsSinceLastTrain++;
if (_observationsSinceLastTrain >= _retrainInterval &&
_observations.length >= minObservations) {
_trainModel();
}
_storage?.saveDeliveryObservations(_observations);
debugPrint(
'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops '
'(${_observations.length} total)',
);
}
int? predictTimeout({
String? contactKey,
required int pathLength,
required int messageBytes,
int secondsSinceLastRx = 0,
}) {
if (_model == null) return null;
try {
if (_activeFeatures.isEmpty) return null;
final allFeatures = {
'pathLength': pathLength.toDouble(),
'messageBytes': messageBytes.toDouble(),
'secSinceRx': secondsSinceLastRx.toDouble(),
'isFlood': pathLength < 0 ? 1.0 : 0.0,
};
final row = _activeFeatures.map((f) => allFeatures[f]!).toList();
final features = DataFrame(
[row],
headerExists: false,
header: _activeFeatures,
);
final prediction = _model!.predict(features);
final rawValue = prediction.rows.first.first;
var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble();
debugPrint(
'TimeoutPrediction: raw prediction=$predictedMs for '
'pathLength=$pathLength, messageBytes=$messageBytes, '
'features=$_activeFeatures',
);
// Sanity check: if prediction is negative or zero, fall back
if (predictedMs <= 0) return null;
// Blend with per-contact mean if enough data
if (contactKey != null) {
final stats = _contactStats[contactKey];
if (stats != null && stats.count >= _minContactObservations) {
predictedMs = 0.5 * predictedMs + 0.5 * stats.mean;
}
}
final timeout =
(predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs);
debugPrint(
'TimeoutPrediction: ML timeout ${timeout}ms '
'(raw: ${predictedMs.round()}ms, contact: $contactKey)',
);
return timeout;
} catch (e) {
debugPrint('TimeoutPrediction: prediction failed: $e');
return null;
}
}
void _trainModel() {
try {
// Build feature columns, then exclude any with zero variance
// (ml_algo's OLS produces all-zero coefficients for singular matrices)
final allNames = ['pathLength', 'messageBytes', 'secSinceRx', 'isFlood'];
final allExtractors = <double Function(DeliveryObservation)>[
(o) => o.pathLength.toDouble(),
(o) => o.messageBytes.toDouble(),
(o) => o.secondsSinceLastRx.toDouble(),
(o) => o.isFlood ? 1.0 : 0.0,
];
_activeFeatures = [];
for (var i = 0; i < allNames.length; i++) {
final values = _observations.map(allExtractors[i]).toSet();
if (values.length > 1) _activeFeatures.add(allNames[i]);
}
if (_activeFeatures.isEmpty) {
debugPrint('TimeoutPrediction: no features with variance, skipping training');
return;
}
final header = [..._activeFeatures, 'deliveryMs'];
final rows = _observations.map((o) {
final row = <double>[];
for (var i = 0; i < allNames.length; i++) {
if (_activeFeatures.contains(allNames[i])) {
row.add(allExtractors[i](o));
}
}
row.add(o.deliveryMs.toDouble());
return row;
});
final data = DataFrame(
[header, ...rows],
headerExists: true,
);
_model = LinearRegressor(data, 'deliveryMs');
_observationsSinceLastTrain = 0;
// Log training summary with sample predictions
final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) /
_observations.length;
debugPrint(
'TimeoutPrediction: trained on ${_observations.length} observations '
'(avg: ${avgMs.round()}ms, features: $_activeFeatures)',
);
final modelJson = jsonEncode(_model!.toJson());
_storage?.saveTimeoutModel(modelJson);
notifyListeners();
} catch (e) {
debugPrint('TimeoutPrediction: training failed: $e');
}
}
void _rebuildContactStats() {
_contactStats.clear();
for (final obs in _observations) {
_contactStats.putIfAbsent(obs.contactKey, () => _ContactStats());
_contactStats[obs.contactKey]!.add(obs.deliveryMs.toDouble());
}
}
}