mirror of
https://github.com/zjs81/meshcore-open.git
synced 2026-04-20 22:13:48 +00:00
feat: add ML-based adaptive timeout prediction using LinearRegressor
Train a linear regression model on actual message delivery times to predict tighter timeouts, replacing worst-case physics estimates. Features: path length, message bytes, seconds since last RX, flood mode. Global model with per-contact blending after 10+ observations per contact. Falls back to existing physics formula when model has insufficient data.
This commit is contained in:
parent
8b280b37be
commit
2ee2358ecc
9 changed files with 683 additions and 20 deletions
|
|
@ -19,6 +19,7 @@ import '../services/message_retry_service.dart';
|
|||
import '../services/path_history_service.dart';
|
||||
import '../services/app_settings_service.dart';
|
||||
import '../services/background_service.dart';
|
||||
import '../services/timeout_prediction_service.dart';
|
||||
import '../services/notification_service.dart';
|
||||
import 'meshcore_connector_usb.dart';
|
||||
import 'meshcore_connector_tcp.dart';
|
||||
|
|
@ -166,6 +167,8 @@ class MeshCoreConnector extends ChangeNotifier {
|
|||
bool _isLoadingContacts = false;
|
||||
bool _isLoadingChannels = false;
|
||||
bool _hasLoadedChannels = false;
|
||||
TimeoutPredictionService? _timeoutPredictionService;
|
||||
DateTime _lastRxTime = DateTime.now();
|
||||
bool _batteryRequested = false;
|
||||
bool _awaitingSelfInfo = false;
|
||||
bool _hasReceivedDeviceInfo = false;
|
||||
|
|
@ -668,6 +671,7 @@ class MeshCoreConnector extends ChangeNotifier {
|
|||
BleDebugLogService? bleDebugLogService,
|
||||
AppDebugLogService? appDebugLogService,
|
||||
BackgroundService? backgroundService,
|
||||
TimeoutPredictionService? timeoutPredictionService,
|
||||
}) {
|
||||
_retryService = retryService;
|
||||
_pathHistoryService = pathHistoryService;
|
||||
|
|
@ -675,6 +679,7 @@ class MeshCoreConnector extends ChangeNotifier {
|
|||
_bleDebugLogService = bleDebugLogService;
|
||||
_appDebugLogService = appDebugLogService;
|
||||
_backgroundService = backgroundService;
|
||||
_timeoutPredictionService = timeoutPredictionService;
|
||||
_usbManager.setDebugLogService(_appDebugLogService);
|
||||
_tcpConnector.setDebugLogService(_appDebugLogService);
|
||||
|
||||
|
|
@ -689,13 +694,23 @@ class MeshCoreConnector extends ChangeNotifier {
|
|||
updateMessageCallback: _updateMessage,
|
||||
clearContactPathCallback: clearContactPath,
|
||||
setContactPathCallback: setContactPath,
|
||||
calculateTimeoutCallback: (pathLength, messageBytes) =>
|
||||
calculateTimeout(pathLength: pathLength, messageBytes: messageBytes),
|
||||
calculateTimeoutCallback: (pathLength, messageBytes, {String? contactKey}) =>
|
||||
calculateTimeout(pathLength: pathLength, messageBytes: messageBytes, contactKey: contactKey),
|
||||
getSelfPublicKeyCallback: () => _selfPublicKey,
|
||||
prepareContactOutboundTextCallback: prepareContactOutboundText,
|
||||
appSettingsService: appSettingsService,
|
||||
debugLogService: _appDebugLogService,
|
||||
recordPathResultCallback: _recordPathResult,
|
||||
onDeliveryObservedCallback: (contactKey, pathLength, messageBytes, tripTimeMs) {
|
||||
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
|
||||
_timeoutPredictionService?.recordObservation(
|
||||
contactKey: contactKey,
|
||||
pathLength: pathLength,
|
||||
messageBytes: messageBytes,
|
||||
tripTimeMs: tripTimeMs,
|
||||
secondsSinceLastRx: secSinceRx,
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -2498,6 +2513,7 @@ class MeshCoreConnector extends ChangeNotifier {
|
|||
|
||||
void _handleFrame(List<int> data) {
|
||||
if (data.isEmpty) return;
|
||||
_lastRxTime = DateTime.now();
|
||||
|
||||
final frame = Uint8List.fromList(data);
|
||||
_receivedFramesController.add(frame);
|
||||
|
|
@ -2876,7 +2892,21 @@ class MeshCoreConnector extends ChangeNotifier {
|
|||
|
||||
/// Calculate timeout for a message based on radio settings and path length
|
||||
/// Returns timeout in milliseconds, considering number of hops
|
||||
int calculateTimeout({required int pathLength, int messageBytes = 100}) {
|
||||
int calculateTimeout({
|
||||
required int pathLength,
|
||||
int messageBytes = 100,
|
||||
String? contactKey,
|
||||
}) {
|
||||
// Try ML-based prediction first
|
||||
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
|
||||
final mlTimeout = _timeoutPredictionService?.predictTimeout(
|
||||
contactKey: contactKey,
|
||||
pathLength: pathLength,
|
||||
messageBytes: messageBytes,
|
||||
secondsSinceLastRx: secSinceRx,
|
||||
);
|
||||
if (mlTimeout != null) return mlTimeout;
|
||||
|
||||
// If we have radio settings, use them for accurate calculation
|
||||
if (_currentFreqHz != null &&
|
||||
_currentBwHz != null &&
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import 'services/app_debug_log_service.dart';
|
|||
import 'services/background_service.dart';
|
||||
import 'services/map_tile_cache_service.dart';
|
||||
import 'services/chat_text_scale_service.dart';
|
||||
import 'services/timeout_prediction_service.dart';
|
||||
import 'storage/prefs_manager.dart';
|
||||
import 'utils/app_logger.dart';
|
||||
|
||||
|
|
@ -39,6 +40,7 @@ void main() async {
|
|||
final backgroundService = BackgroundService();
|
||||
final mapTileCacheService = MapTileCacheService();
|
||||
final chatTextScaleService = ChatTextScaleService();
|
||||
final timeoutPredictionService = TimeoutPredictionService(storage);
|
||||
|
||||
// Load settings
|
||||
await appSettingsService.loadSettings();
|
||||
|
|
@ -56,6 +58,7 @@ void main() async {
|
|||
_registerThirdPartyLicenses();
|
||||
|
||||
await chatTextScaleService.initialize();
|
||||
await timeoutPredictionService.initialize();
|
||||
|
||||
// Wire up connector with services
|
||||
connector.initialize(
|
||||
|
|
@ -65,6 +68,7 @@ void main() async {
|
|||
bleDebugLogService: bleDebugLogService,
|
||||
appDebugLogService: appDebugLogService,
|
||||
backgroundService: backgroundService,
|
||||
timeoutPredictionService: timeoutPredictionService,
|
||||
);
|
||||
|
||||
await connector.loadContactCache();
|
||||
|
|
@ -86,6 +90,7 @@ void main() async {
|
|||
appDebugLogService: appDebugLogService,
|
||||
mapTileCacheService: mapTileCacheService,
|
||||
chatTextScaleService: chatTextScaleService,
|
||||
timeoutPredictionService: timeoutPredictionService,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
|
@ -121,6 +126,7 @@ class MeshCoreApp extends StatelessWidget {
|
|||
final AppDebugLogService appDebugLogService;
|
||||
final MapTileCacheService mapTileCacheService;
|
||||
final ChatTextScaleService chatTextScaleService;
|
||||
final TimeoutPredictionService timeoutPredictionService;
|
||||
|
||||
const MeshCoreApp({
|
||||
super.key,
|
||||
|
|
@ -133,6 +139,7 @@ class MeshCoreApp extends StatelessWidget {
|
|||
required this.appDebugLogService,
|
||||
required this.mapTileCacheService,
|
||||
required this.chatTextScaleService,
|
||||
required this.timeoutPredictionService,
|
||||
});
|
||||
|
||||
@override
|
||||
|
|
@ -148,6 +155,7 @@ class MeshCoreApp extends StatelessWidget {
|
|||
ChangeNotifierProvider.value(value: chatTextScaleService),
|
||||
Provider.value(value: storage),
|
||||
Provider.value(value: mapTileCacheService),
|
||||
ChangeNotifierProvider.value(value: timeoutPredictionService),
|
||||
],
|
||||
child: Consumer<AppSettingsService>(
|
||||
builder: (context, settingsService, child) {
|
||||
|
|
|
|||
43
lib/models/delivery_observation.dart
Normal file
43
lib/models/delivery_observation.dart
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
class DeliveryObservation {
|
||||
final String contactKey;
|
||||
final int pathLength;
|
||||
final int messageBytes;
|
||||
final int secondsSinceLastRx;
|
||||
final bool isFlood;
|
||||
final int deliveryMs;
|
||||
final DateTime timestamp;
|
||||
|
||||
DeliveryObservation({
|
||||
required this.contactKey,
|
||||
required this.pathLength,
|
||||
required this.messageBytes,
|
||||
required this.secondsSinceLastRx,
|
||||
required this.isFlood,
|
||||
required this.deliveryMs,
|
||||
required this.timestamp,
|
||||
});
|
||||
|
||||
Map<String, dynamic> toJson() {
|
||||
return {
|
||||
'contact_key': contactKey,
|
||||
'path_length': pathLength,
|
||||
'message_bytes': messageBytes,
|
||||
'seconds_since_last_rx': secondsSinceLastRx,
|
||||
'is_flood': isFlood,
|
||||
'delivery_ms': deliveryMs,
|
||||
'timestamp': timestamp.toIso8601String(),
|
||||
};
|
||||
}
|
||||
|
||||
factory DeliveryObservation.fromJson(Map<String, dynamic> json) {
|
||||
return DeliveryObservation(
|
||||
contactKey: json['contact_key'] as String,
|
||||
pathLength: json['path_length'] as int,
|
||||
messageBytes: json['message_bytes'] as int,
|
||||
secondsSinceLastRx: json['seconds_since_last_rx'] as int? ?? 0,
|
||||
isFlood: json['is_flood'] as bool,
|
||||
deliveryMs: json['delivery_ms'] as int,
|
||||
timestamp: DateTime.parse(json['timestamp'] as String),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -58,12 +58,13 @@ class MessageRetryService extends ChangeNotifier {
|
|||
Function(Message)? _updateMessageCallback;
|
||||
Function(Contact)? _clearContactPathCallback;
|
||||
Function(Contact, Uint8List, int)? _setContactPathCallback;
|
||||
Function(int, int)? _calculateTimeoutCallback;
|
||||
Function(int, int, {String? contactKey})? _calculateTimeoutCallback;
|
||||
Uint8List? Function()? _getSelfPublicKeyCallback;
|
||||
String Function(Contact, String)? _prepareContactOutboundTextCallback;
|
||||
AppSettingsService? _appSettingsService;
|
||||
AppDebugLogService? _debugLogService;
|
||||
Function(String, PathSelection, bool, int?)? _recordPathResultCallback;
|
||||
Function(String, int, int, int)? _onDeliveryObservedCallback;
|
||||
|
||||
MessageRetryService();
|
||||
|
||||
|
|
@ -73,12 +74,14 @@ class MessageRetryService extends ChangeNotifier {
|
|||
required Function(Message) updateMessageCallback,
|
||||
Function(Contact)? clearContactPathCallback,
|
||||
Function(Contact, Uint8List, int)? setContactPathCallback,
|
||||
Function(int pathLength, int messageBytes)? calculateTimeoutCallback,
|
||||
Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback,
|
||||
Uint8List? Function()? getSelfPublicKeyCallback,
|
||||
String Function(Contact, String)? prepareContactOutboundTextCallback,
|
||||
AppSettingsService? appSettingsService,
|
||||
AppDebugLogService? debugLogService,
|
||||
Function(String, PathSelection, bool, int?)? recordPathResultCallback,
|
||||
Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)?
|
||||
onDeliveryObservedCallback,
|
||||
}) {
|
||||
_sendMessageCallback = sendMessageCallback;
|
||||
_addMessageCallback = addMessageCallback;
|
||||
|
|
@ -91,6 +94,7 @@ class MessageRetryService extends ChangeNotifier {
|
|||
_appSettingsService = appSettingsService;
|
||||
_debugLogService = debugLogService;
|
||||
_recordPathResultCallback = recordPathResultCallback;
|
||||
_onDeliveryObservedCallback = onDeliveryObservedCallback;
|
||||
}
|
||||
|
||||
/// Compute expected ACK hash using same algorithm as firmware:
|
||||
|
|
@ -423,25 +427,33 @@ class MessageRetryService extends ChangeNotifier {
|
|||
);
|
||||
}
|
||||
|
||||
// Use device-provided timeout, or calculate from radio settings if timeout is 0 or invalid
|
||||
// Calculate timeout: prefer ML prediction, then device-provided, then physics fallback
|
||||
int pathLengthValue;
|
||||
if (selection != null) {
|
||||
pathLengthValue = selection.useFlood ? -1 : selection.hopCount;
|
||||
if (pathLengthValue < 0) pathLengthValue = contact.pathLength;
|
||||
} else if (message.pathLength != null) {
|
||||
pathLengthValue = message.pathLength!;
|
||||
} else {
|
||||
pathLengthValue = contact.pathLength;
|
||||
}
|
||||
|
||||
int actualTimeout = timeoutMs;
|
||||
if (timeoutMs <= 0 && _calculateTimeoutCallback != null) {
|
||||
int pathLengthValue;
|
||||
if (selection != null) {
|
||||
pathLengthValue = selection.useFlood ? -1 : selection.hopCount;
|
||||
if (pathLengthValue < 0) pathLengthValue = contact.pathLength;
|
||||
} else if (message.pathLength != null) {
|
||||
pathLengthValue = message.pathLength!;
|
||||
} else {
|
||||
pathLengthValue = contact.pathLength;
|
||||
}
|
||||
actualTimeout = _calculateTimeoutCallback!(
|
||||
if (_calculateTimeoutCallback != null) {
|
||||
final calculated = _calculateTimeoutCallback!(
|
||||
pathLengthValue,
|
||||
message.text.length,
|
||||
contactKey: contact.publicKeyHex,
|
||||
);
|
||||
debugPrint(
|
||||
'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue',
|
||||
);
|
||||
// calculateTimeout tries ML first, falls back to physics.
|
||||
// Use calculated value if device didn't provide one, or if ML
|
||||
// produced a tighter prediction than the device's estimate.
|
||||
if (timeoutMs <= 0 || calculated < timeoutMs) {
|
||||
actualTimeout = calculated;
|
||||
debugPrint(
|
||||
'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
final updatedMessage = message.copyWith(
|
||||
|
|
@ -738,6 +750,14 @@ class MessageRetryService extends ChangeNotifier {
|
|||
true,
|
||||
tripTimeMs,
|
||||
);
|
||||
if (_onDeliveryObservedCallback != null && tripTimeMs > 0) {
|
||||
_onDeliveryObservedCallback!(
|
||||
contact.publicKeyHex,
|
||||
message.pathLength ?? 0,
|
||||
message.text.length,
|
||||
tripTimeMs,
|
||||
);
|
||||
}
|
||||
_onMessageResolved(matchedMessageId, contact.publicKeyHex);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import 'dart:convert';
|
||||
import '../models/delivery_observation.dart';
|
||||
import '../models/path_history.dart';
|
||||
import '../storage/prefs_manager.dart';
|
||||
|
||||
|
|
@ -6,6 +7,8 @@ class StorageService {
|
|||
static const String _pathHistoryPrefix = 'path_history_';
|
||||
static const String _pendingMessagesKey = 'pending_messages';
|
||||
static const String _repeaterPasswordsKey = 'repeater_passwords';
|
||||
static const String _deliveryObservationsKey = 'delivery_observations';
|
||||
static const String _timeoutModelKey = 'timeout_ml_model';
|
||||
|
||||
Future<void> savePathHistory(
|
||||
String contactPubKeyHex,
|
||||
|
|
@ -122,4 +125,51 @@ class StorageService {
|
|||
final prefs = PrefsManager.instance;
|
||||
await prefs.remove(_repeaterPasswordsKey);
|
||||
}
|
||||
|
||||
Future<void> saveDeliveryObservations(
|
||||
List<DeliveryObservation> observations,
|
||||
) async {
|
||||
final prefs = PrefsManager.instance;
|
||||
final jsonStr = jsonEncode(observations.map((o) => o.toJson()).toList());
|
||||
await prefs.setString(_deliveryObservationsKey, jsonStr);
|
||||
}
|
||||
|
||||
Future<List<DeliveryObservation>> loadDeliveryObservations() async {
|
||||
final prefs = PrefsManager.instance;
|
||||
final jsonStr = prefs.getString(_deliveryObservationsKey);
|
||||
|
||||
if (jsonStr == null) return [];
|
||||
|
||||
try {
|
||||
final list = jsonDecode(jsonStr) as List;
|
||||
return list
|
||||
.map(
|
||||
(e) =>
|
||||
DeliveryObservation.fromJson(e as Map<String, dynamic>),
|
||||
)
|
||||
.toList();
|
||||
} catch (e) {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> clearDeliveryObservations() async {
|
||||
final prefs = PrefsManager.instance;
|
||||
await prefs.remove(_deliveryObservationsKey);
|
||||
}
|
||||
|
||||
Future<void> saveTimeoutModel(String modelJson) async {
|
||||
final prefs = PrefsManager.instance;
|
||||
await prefs.setString(_timeoutModelKey, modelJson);
|
||||
}
|
||||
|
||||
Future<String?> loadTimeoutModel() async {
|
||||
final prefs = PrefsManager.instance;
|
||||
return prefs.getString(_timeoutModelKey);
|
||||
}
|
||||
|
||||
Future<void> clearTimeoutModel() async {
|
||||
final prefs = PrefsManager.instance;
|
||||
await prefs.remove(_timeoutModelKey);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
224
lib/services/timeout_prediction_service.dart
Normal file
224
lib/services/timeout_prediction_service.dart
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
import 'dart:convert';
|
||||
import 'dart:math';
|
||||
import 'package:flutter/foundation.dart';
|
||||
import 'package:ml_algo/ml_algo.dart';
|
||||
import 'package:ml_dataframe/ml_dataframe.dart';
|
||||
import '../models/delivery_observation.dart';
|
||||
import 'storage_service.dart';
|
||||
|
||||
class _ContactStats {
|
||||
int count = 0;
|
||||
double _sum = 0;
|
||||
double _sumSq = 0;
|
||||
|
||||
void add(double ms) {
|
||||
count++;
|
||||
_sum += ms;
|
||||
_sumSq += ms * ms;
|
||||
}
|
||||
|
||||
double get mean => _sum / count;
|
||||
double get stdDev => sqrt((_sumSq / count) - (mean * mean));
|
||||
}
|
||||
|
||||
class TimeoutPredictionService extends ChangeNotifier {
|
||||
final StorageService? _storage;
|
||||
|
||||
static const int minObservations = 10;
|
||||
static const int maxObservations = 100;
|
||||
static const int _retrainInterval = 5;
|
||||
static const double _safetyMargin = 1.5;
|
||||
static const int _minTimeoutMs = 2000;
|
||||
static const int _maxTimeoutMs = 120000;
|
||||
static const int _minContactObservations = 10;
|
||||
|
||||
List<DeliveryObservation> _observations = [];
|
||||
LinearRegressor? _model;
|
||||
List<String> _activeFeatures = [];
|
||||
int _observationsSinceLastTrain = 0;
|
||||
final Map<String, _ContactStats> _contactStats = {};
|
||||
|
||||
TimeoutPredictionService(StorageService storage) : _storage = storage;
|
||||
TimeoutPredictionService.noStorage() : _storage = null;
|
||||
|
||||
int get observationCount => _observations.length;
|
||||
bool get hasModel => _model != null;
|
||||
|
||||
Future<void> initialize() async {
|
||||
_observations = await _storage?.loadDeliveryObservations() ?? [];
|
||||
_rebuildContactStats();
|
||||
|
||||
if (_observations.length >= minObservations) {
|
||||
_trainModel();
|
||||
}
|
||||
|
||||
debugPrint(
|
||||
'TimeoutPrediction: initialized with ${_observations.length} observations, '
|
||||
'model=${_model != null ? "ready" : "waiting for data"}',
|
||||
);
|
||||
}
|
||||
|
||||
void recordObservation({
|
||||
required String contactKey,
|
||||
required int pathLength,
|
||||
required int messageBytes,
|
||||
required int tripTimeMs,
|
||||
int secondsSinceLastRx = 0,
|
||||
}) {
|
||||
final observation = DeliveryObservation(
|
||||
contactKey: contactKey,
|
||||
pathLength: pathLength,
|
||||
messageBytes: messageBytes,
|
||||
secondsSinceLastRx: secondsSinceLastRx,
|
||||
isFlood: pathLength < 0,
|
||||
deliveryMs: tripTimeMs,
|
||||
timestamp: DateTime.now(),
|
||||
);
|
||||
|
||||
_observations.add(observation);
|
||||
if (_observations.length > maxObservations) {
|
||||
_observations.removeAt(0);
|
||||
}
|
||||
|
||||
_contactStats.putIfAbsent(contactKey, () => _ContactStats());
|
||||
_contactStats[contactKey]!.add(tripTimeMs.toDouble());
|
||||
|
||||
_observationsSinceLastTrain++;
|
||||
if (_observationsSinceLastTrain >= _retrainInterval &&
|
||||
_observations.length >= minObservations) {
|
||||
_trainModel();
|
||||
}
|
||||
|
||||
_storage?.saveDeliveryObservations(_observations);
|
||||
debugPrint(
|
||||
'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops '
|
||||
'(${_observations.length} total)',
|
||||
);
|
||||
}
|
||||
|
||||
int? predictTimeout({
|
||||
String? contactKey,
|
||||
required int pathLength,
|
||||
required int messageBytes,
|
||||
int secondsSinceLastRx = 0,
|
||||
}) {
|
||||
if (_model == null) return null;
|
||||
|
||||
try {
|
||||
if (_activeFeatures.isEmpty) return null;
|
||||
|
||||
final allFeatures = {
|
||||
'pathLength': pathLength.toDouble(),
|
||||
'messageBytes': messageBytes.toDouble(),
|
||||
'secSinceRx': secondsSinceLastRx.toDouble(),
|
||||
'isFlood': pathLength < 0 ? 1.0 : 0.0,
|
||||
};
|
||||
final row = _activeFeatures.map((f) => allFeatures[f]!).toList();
|
||||
|
||||
final features = DataFrame(
|
||||
[row],
|
||||
headerExists: false,
|
||||
header: _activeFeatures,
|
||||
);
|
||||
|
||||
final prediction = _model!.predict(features);
|
||||
final rawValue = prediction.rows.first.first;
|
||||
var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble();
|
||||
|
||||
debugPrint(
|
||||
'TimeoutPrediction: raw prediction=$predictedMs for '
|
||||
'pathLength=$pathLength, messageBytes=$messageBytes, '
|
||||
'features=$_activeFeatures',
|
||||
);
|
||||
|
||||
// Sanity check: if prediction is negative or zero, fall back
|
||||
if (predictedMs <= 0) return null;
|
||||
|
||||
// Blend with per-contact mean if enough data
|
||||
if (contactKey != null) {
|
||||
final stats = _contactStats[contactKey];
|
||||
if (stats != null && stats.count >= _minContactObservations) {
|
||||
predictedMs = 0.5 * predictedMs + 0.5 * stats.mean;
|
||||
}
|
||||
}
|
||||
|
||||
final timeout =
|
||||
(predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs);
|
||||
debugPrint(
|
||||
'TimeoutPrediction: ML timeout ${timeout}ms '
|
||||
'(raw: ${predictedMs.round()}ms, contact: $contactKey)',
|
||||
);
|
||||
return timeout;
|
||||
} catch (e) {
|
||||
debugPrint('TimeoutPrediction: prediction failed: $e');
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
void _trainModel() {
|
||||
try {
|
||||
// Build feature columns, then exclude any with zero variance
|
||||
// (ml_algo's OLS produces all-zero coefficients for singular matrices)
|
||||
final allNames = ['pathLength', 'messageBytes', 'secSinceRx', 'isFlood'];
|
||||
final allExtractors = <double Function(DeliveryObservation)>[
|
||||
(o) => o.pathLength.toDouble(),
|
||||
(o) => o.messageBytes.toDouble(),
|
||||
(o) => o.secondsSinceLastRx.toDouble(),
|
||||
(o) => o.isFlood ? 1.0 : 0.0,
|
||||
];
|
||||
|
||||
_activeFeatures = [];
|
||||
for (var i = 0; i < allNames.length; i++) {
|
||||
final values = _observations.map(allExtractors[i]).toSet();
|
||||
if (values.length > 1) _activeFeatures.add(allNames[i]);
|
||||
}
|
||||
|
||||
if (_activeFeatures.isEmpty) {
|
||||
debugPrint('TimeoutPrediction: no features with variance, skipping training');
|
||||
return;
|
||||
}
|
||||
|
||||
final header = [..._activeFeatures, 'deliveryMs'];
|
||||
final rows = _observations.map((o) {
|
||||
final row = <double>[];
|
||||
for (var i = 0; i < allNames.length; i++) {
|
||||
if (_activeFeatures.contains(allNames[i])) {
|
||||
row.add(allExtractors[i](o));
|
||||
}
|
||||
}
|
||||
row.add(o.deliveryMs.toDouble());
|
||||
return row;
|
||||
});
|
||||
|
||||
final data = DataFrame(
|
||||
[header, ...rows],
|
||||
headerExists: true,
|
||||
);
|
||||
|
||||
_model = LinearRegressor(data, 'deliveryMs');
|
||||
_observationsSinceLastTrain = 0;
|
||||
|
||||
// Log training summary with sample predictions
|
||||
final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) /
|
||||
_observations.length;
|
||||
debugPrint(
|
||||
'TimeoutPrediction: trained on ${_observations.length} observations '
|
||||
'(avg: ${avgMs.round()}ms, features: $_activeFeatures)',
|
||||
);
|
||||
|
||||
final modelJson = jsonEncode(_model!.toJson());
|
||||
_storage?.saveTimeoutModel(modelJson);
|
||||
notifyListeners();
|
||||
} catch (e) {
|
||||
debugPrint('TimeoutPrediction: training failed: $e');
|
||||
}
|
||||
}
|
||||
|
||||
void _rebuildContactStats() {
|
||||
_contactStats.clear();
|
||||
for (final obs in _observations) {
|
||||
_contactStats.putIfAbsent(obs.contactKey, () => _ContactStats());
|
||||
_contactStats[obs.contactKey]!.add(obs.deliveryMs.toDouble());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -69,6 +69,8 @@ dependencies:
|
|||
material_symbols_icons: ^4.2906.0
|
||||
web: ^1.1.1
|
||||
flutter_svg: ^2.0.10+1
|
||||
ml_algo: ^16.0.0
|
||||
ml_dataframe: ^1.0.0
|
||||
|
||||
dev_dependencies:
|
||||
flutter_test:
|
||||
|
|
|
|||
122
test/services/ml_algo_sanity_test.dart
Normal file
122
test/services/ml_algo_sanity_test.dart
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import 'package:flutter/foundation.dart';
|
||||
import 'package:flutter_test/flutter_test.dart';
|
||||
import 'package:ml_algo/ml_algo.dart';
|
||||
import 'package:ml_dataframe/ml_dataframe.dart';
|
||||
|
||||
void main() {
|
||||
test('LinearRegressor basic sanity check', () {
|
||||
// Simple: y = 2x + 100
|
||||
final data = DataFrame([
|
||||
[1.0, 102.0],
|
||||
[2.0, 104.0],
|
||||
[3.0, 106.0],
|
||||
[4.0, 108.0],
|
||||
[5.0, 110.0],
|
||||
[10.0, 120.0],
|
||||
[20.0, 140.0],
|
||||
[50.0, 200.0],
|
||||
[0.0, 100.0],
|
||||
[100.0, 300.0],
|
||||
], headerExists: false, header: ['x', 'y']);
|
||||
|
||||
debugPrint('Training data columns: ${data.header}');
|
||||
debugPrint('Training data rows: ${data.rows.length}');
|
||||
|
||||
final model = LinearRegressor(data, 'y');
|
||||
|
||||
final testDf = DataFrame(
|
||||
[[25.0]],
|
||||
headerExists: false,
|
||||
header: ['x'],
|
||||
);
|
||||
|
||||
final prediction = model.predict(testDf);
|
||||
final value = prediction.rows.first.first;
|
||||
debugPrint('Predict x=25 → y=$value (expected ~150)');
|
||||
expect((value as num).toDouble(), closeTo(150, 5));
|
||||
});
|
||||
|
||||
test('LinearRegressor multi-feature with constant column produces zeros', () {
|
||||
// isFlood=0 for all rows → zero-variance column → singular matrix
|
||||
final data = DataFrame([
|
||||
[0.0, 50.0, 14.0, 0.0, 1900.0],
|
||||
[0.0, 80.0, 14.0, 0.0, 2200.0],
|
||||
[2.0, 50.0, 14.0, 0.0, 5000.0],
|
||||
[4.0, 50.0, 14.0, 0.0, 9500.0],
|
||||
], headerExists: false, header: [
|
||||
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
|
||||
]);
|
||||
|
||||
final model = LinearRegressor(data, 'deliveryMs');
|
||||
final testDf = DataFrame(
|
||||
[[2.0, 50.0, 14.0, 0.0]],
|
||||
headerExists: false,
|
||||
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
|
||||
);
|
||||
final pred = model.predict(testDf).rows.first.first;
|
||||
debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)');
|
||||
});
|
||||
|
||||
test('LinearRegressor 2-feature works correctly', () {
|
||||
// Just pathLength + messageBytes → deliveryMs
|
||||
final data = DataFrame([
|
||||
[0.0, 50.0, 1900.0],
|
||||
[0.0, 80.0, 2200.0],
|
||||
[2.0, 50.0, 5000.0],
|
||||
[2.0, 80.0, 5500.0],
|
||||
[4.0, 50.0, 9500.0],
|
||||
[4.0, 80.0, 10000.0],
|
||||
[0.0, 30.0, 1800.0],
|
||||
[2.0, 30.0, 4800.0],
|
||||
[4.0, 30.0, 9000.0],
|
||||
[0.0, 60.0, 2000.0],
|
||||
], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']);
|
||||
|
||||
final model = LinearRegressor(data, 'deliveryMs');
|
||||
|
||||
for (final hops in [0.0, 2.0, 4.0]) {
|
||||
final testDf = DataFrame(
|
||||
[[hops, 50.0]],
|
||||
headerExists: false,
|
||||
header: ['pathLength', 'messageBytes'],
|
||||
);
|
||||
final pred = model.predict(testDf).rows.first.first;
|
||||
debugPrint('2-feature: hops=$hops → ${(pred as num).round()}ms');
|
||||
}
|
||||
});
|
||||
|
||||
test('LinearRegressor multi-feature with variance in all columns', () {
|
||||
// Mix flood and direct so isFlood has variance
|
||||
final data = DataFrame([
|
||||
[0.0, 50.0, 14.0, 0.0, 1900.0],
|
||||
[0.0, 80.0, 10.0, 0.0, 2200.0],
|
||||
[2.0, 50.0, 16.0, 0.0, 5000.0],
|
||||
[2.0, 80.0, 20.0, 0.0, 5500.0],
|
||||
[4.0, 50.0, 8.0, 0.0, 9500.0],
|
||||
[4.0, 80.0, 12.0, 0.0, 10000.0],
|
||||
[-1.0, 40.0, 14.0, 1.0, 5000.0],
|
||||
[-1.0, 60.0, 18.0, 1.0, 6500.0],
|
||||
[-1.0, 30.0, 10.0, 1.0, 4000.0],
|
||||
[-1.0, 80.0, 22.0, 1.0, 7000.0],
|
||||
], headerExists: false, header: [
|
||||
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
|
||||
]);
|
||||
|
||||
final model = LinearRegressor(data, 'deliveryMs');
|
||||
|
||||
for (final tc in [
|
||||
[0.0, 50.0, 14.0, 0.0],
|
||||
[2.0, 50.0, 14.0, 0.0],
|
||||
[4.0, 50.0, 14.0, 0.0],
|
||||
[-1.0, 50.0, 14.0, 1.0],
|
||||
]) {
|
||||
final testDf = DataFrame(
|
||||
[tc],
|
||||
headerExists: false,
|
||||
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
|
||||
);
|
||||
final pred = model.predict(testDf).rows.first.first;
|
||||
debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms');
|
||||
}
|
||||
});
|
||||
}
|
||||
164
test/services/timeout_prediction_service_test.dart
Normal file
164
test/services/timeout_prediction_service_test.dart
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import 'package:flutter/foundation.dart';
|
||||
import 'package:flutter_test/flutter_test.dart';
|
||||
import 'package:meshcore_open/models/delivery_observation.dart';
|
||||
import 'package:meshcore_open/services/timeout_prediction_service.dart';
|
||||
|
||||
void main() {
|
||||
late TimeoutPredictionService service;
|
||||
|
||||
setUp(() {
|
||||
service = TimeoutPredictionService.noStorage();
|
||||
});
|
||||
|
||||
test('trains on sample data and predicts sensible timeouts', () {
|
||||
// Simulate realistic delivery data:
|
||||
// Direct 0-hop messages: ~1500-2500ms
|
||||
// 2-hop messages: ~4000-6000ms
|
||||
// 4-hop messages: ~8000-12000ms
|
||||
// Flood messages: ~3000-8000ms
|
||||
final sampleData = [
|
||||
// 0-hop direct
|
||||
_obs(pathLength: 0, messageBytes: 20, deliveryMs: 1800),
|
||||
_obs(pathLength: 0, messageBytes: 50, deliveryMs: 2100),
|
||||
_obs(pathLength: 0, messageBytes: 80, deliveryMs: 2400),
|
||||
_obs(pathLength: 0, messageBytes: 30, deliveryMs: 1925),
|
||||
// 2-hop direct
|
||||
_obs(pathLength: 2, messageBytes: 40, deliveryMs: 4500),
|
||||
_obs(pathLength: 2, messageBytes: 60, deliveryMs: 5200),
|
||||
_obs(pathLength: 2, messageBytes: 25, deliveryMs: 4100),
|
||||
// 4-hop direct
|
||||
_obs(pathLength: 4, messageBytes: 50, deliveryMs: 9800),
|
||||
_obs(pathLength: 4, messageBytes: 30, deliveryMs: 8500),
|
||||
_obs(pathLength: 4, messageBytes: 70, deliveryMs: 10570),
|
||||
// Flood
|
||||
_obs(pathLength: -1, messageBytes: 40, deliveryMs: 5000),
|
||||
_obs(pathLength: -1, messageBytes: 60, deliveryMs: 6500),
|
||||
];
|
||||
|
||||
// Feed all observations
|
||||
for (final obs in sampleData) {
|
||||
service.recordObservation(
|
||||
contactKey: obs.contactKey,
|
||||
pathLength: obs.pathLength,
|
||||
messageBytes: obs.messageBytes,
|
||||
tripTimeMs: obs.deliveryMs,
|
||||
);
|
||||
}
|
||||
|
||||
expect(service.hasModel, isTrue);
|
||||
expect(service.observationCount, equals(12));
|
||||
|
||||
// Predict for different scenarios
|
||||
final direct0 = service.predictTimeout(pathLength: 0, messageBytes: 50);
|
||||
final direct2 = service.predictTimeout(pathLength: 2, messageBytes: 50);
|
||||
final direct4 = service.predictTimeout(pathLength: 4, messageBytes: 50);
|
||||
final flood = service.predictTimeout(pathLength: -1, messageBytes: 50);
|
||||
|
||||
// All should return non-null (model is trained)
|
||||
expect(direct0, isNotNull);
|
||||
expect(direct2, isNotNull);
|
||||
expect(direct4, isNotNull);
|
||||
expect(flood, isNotNull);
|
||||
|
||||
// More hops should predict longer timeouts
|
||||
expect(direct4!, greaterThan(direct2!));
|
||||
expect(direct2, greaterThan(direct0!));
|
||||
|
||||
// All should be within the clamp range
|
||||
expect(direct0, greaterThanOrEqualTo(2000));
|
||||
expect(direct4, lessThanOrEqualTo(120000));
|
||||
|
||||
// Print predictions for visibility
|
||||
debugPrint('Predictions (with 1.5x safety margin):');
|
||||
debugPrint(' 0-hop direct: ${direct0}ms');
|
||||
debugPrint(' 2-hop direct: ${direct2}ms');
|
||||
debugPrint(' 4-hop direct: ${direct4}ms');
|
||||
debugPrint(' flood: ${flood}ms');
|
||||
});
|
||||
|
||||
test('returns null before minimum observations', () {
|
||||
for (var i = 0; i < TimeoutPredictionService.minObservations - 1; i++) {
|
||||
service.recordObservation(
|
||||
contactKey: 'abc',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
tripTimeMs: 2000,
|
||||
);
|
||||
}
|
||||
|
||||
expect(service.hasModel, isFalse);
|
||||
expect(service.predictTimeout(pathLength: 0, messageBytes: 50), isNull);
|
||||
});
|
||||
|
||||
test('caps observations at maxObservations', () {
|
||||
for (var i = 0; i < TimeoutPredictionService.maxObservations + 20; i++) {
|
||||
service.recordObservation(
|
||||
contactKey: 'abc',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
tripTimeMs: 2000 + i,
|
||||
);
|
||||
}
|
||||
|
||||
expect(
|
||||
service.observationCount,
|
||||
equals(TimeoutPredictionService.maxObservations),
|
||||
);
|
||||
});
|
||||
|
||||
test('blends per-contact stats after enough observations', () {
|
||||
// Train with mixed contacts and varied features:
|
||||
// contactA is fast (0-hop), contactB is slow (2-hop)
|
||||
for (var i = 0; i < 12; i++) {
|
||||
service.recordObservation(
|
||||
contactKey: 'contactA',
|
||||
pathLength: 0,
|
||||
messageBytes: 30 + i,
|
||||
tripTimeMs: 1500,
|
||||
);
|
||||
service.recordObservation(
|
||||
contactKey: 'contactB',
|
||||
pathLength: 2,
|
||||
messageBytes: 30 + i,
|
||||
tripTimeMs: 8000,
|
||||
);
|
||||
}
|
||||
|
||||
final predA = service.predictTimeout(
|
||||
contactKey: 'contactA',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
);
|
||||
final predB = service.predictTimeout(
|
||||
contactKey: 'contactB',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
);
|
||||
|
||||
expect(predA, isNotNull);
|
||||
expect(predB, isNotNull);
|
||||
// Contact B (slow) should have a higher predicted timeout than A (fast)
|
||||
expect(predB!, greaterThan(predA!));
|
||||
|
||||
debugPrint('Per-contact blending:');
|
||||
debugPrint(' contactA (fast): ${predA}ms');
|
||||
debugPrint(' contactB (slow): ${predB}ms');
|
||||
});
|
||||
}
|
||||
|
||||
DeliveryObservation _obs({
|
||||
required int pathLength,
|
||||
required int messageBytes,
|
||||
required int deliveryMs,
|
||||
String contactKey = 'test_contact',
|
||||
}) {
|
||||
return DeliveryObservation(
|
||||
contactKey: contactKey,
|
||||
pathLength: pathLength,
|
||||
messageBytes: messageBytes,
|
||||
secondsSinceLastRx: 5,
|
||||
isFlood: pathLength < 0,
|
||||
deliveryMs: deliveryMs,
|
||||
timestamp: DateTime.now(),
|
||||
);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue