diff --git a/lib/connector/meshcore_connector.dart b/lib/connector/meshcore_connector.dart index 7cf32ef..d05a8f9 100644 --- a/lib/connector/meshcore_connector.dart +++ b/lib/connector/meshcore_connector.dart @@ -19,6 +19,7 @@ import '../services/message_retry_service.dart'; import '../services/path_history_service.dart'; import '../services/app_settings_service.dart'; import '../services/background_service.dart'; +import '../services/timeout_prediction_service.dart'; import '../services/notification_service.dart'; import 'meshcore_connector_usb.dart'; import 'meshcore_connector_tcp.dart'; @@ -166,6 +167,8 @@ class MeshCoreConnector extends ChangeNotifier { bool _isLoadingContacts = false; bool _isLoadingChannels = false; bool _hasLoadedChannels = false; + TimeoutPredictionService? _timeoutPredictionService; + DateTime _lastRxTime = DateTime.now(); bool _batteryRequested = false; bool _awaitingSelfInfo = false; bool _hasReceivedDeviceInfo = false; @@ -668,6 +671,7 @@ class MeshCoreConnector extends ChangeNotifier { BleDebugLogService? bleDebugLogService, AppDebugLogService? appDebugLogService, BackgroundService? backgroundService, + TimeoutPredictionService? timeoutPredictionService, }) { _retryService = retryService; _pathHistoryService = pathHistoryService; @@ -675,6 +679,7 @@ class MeshCoreConnector extends ChangeNotifier { _bleDebugLogService = bleDebugLogService; _appDebugLogService = appDebugLogService; _backgroundService = backgroundService; + _timeoutPredictionService = timeoutPredictionService; _usbManager.setDebugLogService(_appDebugLogService); _tcpConnector.setDebugLogService(_appDebugLogService); @@ -689,13 +694,23 @@ class MeshCoreConnector extends ChangeNotifier { updateMessageCallback: _updateMessage, clearContactPathCallback: clearContactPath, setContactPathCallback: setContactPath, - calculateTimeoutCallback: (pathLength, messageBytes) => - calculateTimeout(pathLength: pathLength, messageBytes: messageBytes), + calculateTimeoutCallback: (pathLength, messageBytes, {String? contactKey}) => + calculateTimeout(pathLength: pathLength, messageBytes: messageBytes, contactKey: contactKey), getSelfPublicKeyCallback: () => _selfPublicKey, prepareContactOutboundTextCallback: prepareContactOutboundText, appSettingsService: appSettingsService, debugLogService: _appDebugLogService, recordPathResultCallback: _recordPathResult, + onDeliveryObservedCallback: (contactKey, pathLength, messageBytes, tripTimeMs) { + final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; + _timeoutPredictionService?.recordObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + tripTimeMs: tripTimeMs, + secondsSinceLastRx: secSinceRx, + ); + }, ); } @@ -2498,6 +2513,7 @@ class MeshCoreConnector extends ChangeNotifier { void _handleFrame(List data) { if (data.isEmpty) return; + _lastRxTime = DateTime.now(); final frame = Uint8List.fromList(data); _receivedFramesController.add(frame); @@ -2876,7 +2892,21 @@ class MeshCoreConnector extends ChangeNotifier { /// Calculate timeout for a message based on radio settings and path length /// Returns timeout in milliseconds, considering number of hops - int calculateTimeout({required int pathLength, int messageBytes = 100}) { + int calculateTimeout({ + required int pathLength, + int messageBytes = 100, + String? contactKey, + }) { + // Try ML-based prediction first + final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; + final mlTimeout = _timeoutPredictionService?.predictTimeout( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + secondsSinceLastRx: secSinceRx, + ); + if (mlTimeout != null) return mlTimeout; + // If we have radio settings, use them for accurate calculation if (_currentFreqHz != null && _currentBwHz != null && diff --git a/lib/main.dart b/lib/main.dart index 9e53e21..72909e2 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -19,6 +19,7 @@ import 'services/app_debug_log_service.dart'; import 'services/background_service.dart'; import 'services/map_tile_cache_service.dart'; import 'services/chat_text_scale_service.dart'; +import 'services/timeout_prediction_service.dart'; import 'storage/prefs_manager.dart'; import 'utils/app_logger.dart'; @@ -39,6 +40,7 @@ void main() async { final backgroundService = BackgroundService(); final mapTileCacheService = MapTileCacheService(); final chatTextScaleService = ChatTextScaleService(); + final timeoutPredictionService = TimeoutPredictionService(storage); // Load settings await appSettingsService.loadSettings(); @@ -56,6 +58,7 @@ void main() async { _registerThirdPartyLicenses(); await chatTextScaleService.initialize(); + await timeoutPredictionService.initialize(); // Wire up connector with services connector.initialize( @@ -65,6 +68,7 @@ void main() async { bleDebugLogService: bleDebugLogService, appDebugLogService: appDebugLogService, backgroundService: backgroundService, + timeoutPredictionService: timeoutPredictionService, ); await connector.loadContactCache(); @@ -86,6 +90,7 @@ void main() async { appDebugLogService: appDebugLogService, mapTileCacheService: mapTileCacheService, chatTextScaleService: chatTextScaleService, + timeoutPredictionService: timeoutPredictionService, ), ); } @@ -121,6 +126,7 @@ class MeshCoreApp extends StatelessWidget { final AppDebugLogService appDebugLogService; final MapTileCacheService mapTileCacheService; final ChatTextScaleService chatTextScaleService; + final TimeoutPredictionService timeoutPredictionService; const MeshCoreApp({ super.key, @@ -133,6 +139,7 @@ class MeshCoreApp extends StatelessWidget { required this.appDebugLogService, required this.mapTileCacheService, required this.chatTextScaleService, + required this.timeoutPredictionService, }); @override @@ -148,6 +155,7 @@ class MeshCoreApp extends StatelessWidget { ChangeNotifierProvider.value(value: chatTextScaleService), Provider.value(value: storage), Provider.value(value: mapTileCacheService), + ChangeNotifierProvider.value(value: timeoutPredictionService), ], child: Consumer( builder: (context, settingsService, child) { diff --git a/lib/models/delivery_observation.dart b/lib/models/delivery_observation.dart new file mode 100644 index 0000000..a598d2a --- /dev/null +++ b/lib/models/delivery_observation.dart @@ -0,0 +1,43 @@ +class DeliveryObservation { + final String contactKey; + final int pathLength; + final int messageBytes; + final int secondsSinceLastRx; + final bool isFlood; + final int deliveryMs; + final DateTime timestamp; + + DeliveryObservation({ + required this.contactKey, + required this.pathLength, + required this.messageBytes, + required this.secondsSinceLastRx, + required this.isFlood, + required this.deliveryMs, + required this.timestamp, + }); + + Map toJson() { + return { + 'contact_key': contactKey, + 'path_length': pathLength, + 'message_bytes': messageBytes, + 'seconds_since_last_rx': secondsSinceLastRx, + 'is_flood': isFlood, + 'delivery_ms': deliveryMs, + 'timestamp': timestamp.toIso8601String(), + }; + } + + factory DeliveryObservation.fromJson(Map json) { + return DeliveryObservation( + contactKey: json['contact_key'] as String, + pathLength: json['path_length'] as int, + messageBytes: json['message_bytes'] as int, + secondsSinceLastRx: json['seconds_since_last_rx'] as int? ?? 0, + isFlood: json['is_flood'] as bool, + deliveryMs: json['delivery_ms'] as int, + timestamp: DateTime.parse(json['timestamp'] as String), + ); + } +} diff --git a/lib/services/message_retry_service.dart b/lib/services/message_retry_service.dart index db4475f..d94b763 100644 --- a/lib/services/message_retry_service.dart +++ b/lib/services/message_retry_service.dart @@ -58,12 +58,13 @@ class MessageRetryService extends ChangeNotifier { Function(Message)? _updateMessageCallback; Function(Contact)? _clearContactPathCallback; Function(Contact, Uint8List, int)? _setContactPathCallback; - Function(int, int)? _calculateTimeoutCallback; + Function(int, int, {String? contactKey})? _calculateTimeoutCallback; Uint8List? Function()? _getSelfPublicKeyCallback; String Function(Contact, String)? _prepareContactOutboundTextCallback; AppSettingsService? _appSettingsService; AppDebugLogService? _debugLogService; Function(String, PathSelection, bool, int?)? _recordPathResultCallback; + Function(String, int, int, int)? _onDeliveryObservedCallback; MessageRetryService(); @@ -73,12 +74,14 @@ class MessageRetryService extends ChangeNotifier { required Function(Message) updateMessageCallback, Function(Contact)? clearContactPathCallback, Function(Contact, Uint8List, int)? setContactPathCallback, - Function(int pathLength, int messageBytes)? calculateTimeoutCallback, + Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback, Uint8List? Function()? getSelfPublicKeyCallback, String Function(Contact, String)? prepareContactOutboundTextCallback, AppSettingsService? appSettingsService, AppDebugLogService? debugLogService, Function(String, PathSelection, bool, int?)? recordPathResultCallback, + Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)? + onDeliveryObservedCallback, }) { _sendMessageCallback = sendMessageCallback; _addMessageCallback = addMessageCallback; @@ -91,6 +94,7 @@ class MessageRetryService extends ChangeNotifier { _appSettingsService = appSettingsService; _debugLogService = debugLogService; _recordPathResultCallback = recordPathResultCallback; + _onDeliveryObservedCallback = onDeliveryObservedCallback; } /// Compute expected ACK hash using same algorithm as firmware: @@ -423,25 +427,33 @@ class MessageRetryService extends ChangeNotifier { ); } - // Use device-provided timeout, or calculate from radio settings if timeout is 0 or invalid + // Calculate timeout: prefer ML prediction, then device-provided, then physics fallback + int pathLengthValue; + if (selection != null) { + pathLengthValue = selection.useFlood ? -1 : selection.hopCount; + if (pathLengthValue < 0) pathLengthValue = contact.pathLength; + } else if (message.pathLength != null) { + pathLengthValue = message.pathLength!; + } else { + pathLengthValue = contact.pathLength; + } + int actualTimeout = timeoutMs; - if (timeoutMs <= 0 && _calculateTimeoutCallback != null) { - int pathLengthValue; - if (selection != null) { - pathLengthValue = selection.useFlood ? -1 : selection.hopCount; - if (pathLengthValue < 0) pathLengthValue = contact.pathLength; - } else if (message.pathLength != null) { - pathLengthValue = message.pathLength!; - } else { - pathLengthValue = contact.pathLength; - } - actualTimeout = _calculateTimeoutCallback!( + if (_calculateTimeoutCallback != null) { + final calculated = _calculateTimeoutCallback!( pathLengthValue, message.text.length, + contactKey: contact.publicKeyHex, ); - debugPrint( - 'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue', - ); + // calculateTimeout tries ML first, falls back to physics. + // Use calculated value if device didn't provide one, or if ML + // produced a tighter prediction than the device's estimate. + if (timeoutMs <= 0 || calculated < timeoutMs) { + actualTimeout = calculated; + debugPrint( + 'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue', + ); + } } final updatedMessage = message.copyWith( @@ -738,6 +750,14 @@ class MessageRetryService extends ChangeNotifier { true, tripTimeMs, ); + if (_onDeliveryObservedCallback != null && tripTimeMs > 0) { + _onDeliveryObservedCallback!( + contact.publicKeyHex, + message.pathLength ?? 0, + message.text.length, + tripTimeMs, + ); + } _onMessageResolved(matchedMessageId, contact.publicKeyHex); } diff --git a/lib/services/storage_service.dart b/lib/services/storage_service.dart index ce0c4f1..c591f64 100644 --- a/lib/services/storage_service.dart +++ b/lib/services/storage_service.dart @@ -1,4 +1,5 @@ import 'dart:convert'; +import '../models/delivery_observation.dart'; import '../models/path_history.dart'; import '../storage/prefs_manager.dart'; @@ -6,6 +7,8 @@ class StorageService { static const String _pathHistoryPrefix = 'path_history_'; static const String _pendingMessagesKey = 'pending_messages'; static const String _repeaterPasswordsKey = 'repeater_passwords'; + static const String _deliveryObservationsKey = 'delivery_observations'; + static const String _timeoutModelKey = 'timeout_ml_model'; Future savePathHistory( String contactPubKeyHex, @@ -122,4 +125,51 @@ class StorageService { final prefs = PrefsManager.instance; await prefs.remove(_repeaterPasswordsKey); } + + Future saveDeliveryObservations( + List observations, + ) async { + final prefs = PrefsManager.instance; + final jsonStr = jsonEncode(observations.map((o) => o.toJson()).toList()); + await prefs.setString(_deliveryObservationsKey, jsonStr); + } + + Future> loadDeliveryObservations() async { + final prefs = PrefsManager.instance; + final jsonStr = prefs.getString(_deliveryObservationsKey); + + if (jsonStr == null) return []; + + try { + final list = jsonDecode(jsonStr) as List; + return list + .map( + (e) => + DeliveryObservation.fromJson(e as Map), + ) + .toList(); + } catch (e) { + return []; + } + } + + Future clearDeliveryObservations() async { + final prefs = PrefsManager.instance; + await prefs.remove(_deliveryObservationsKey); + } + + Future saveTimeoutModel(String modelJson) async { + final prefs = PrefsManager.instance; + await prefs.setString(_timeoutModelKey, modelJson); + } + + Future loadTimeoutModel() async { + final prefs = PrefsManager.instance; + return prefs.getString(_timeoutModelKey); + } + + Future clearTimeoutModel() async { + final prefs = PrefsManager.instance; + await prefs.remove(_timeoutModelKey); + } } diff --git a/lib/services/timeout_prediction_service.dart b/lib/services/timeout_prediction_service.dart new file mode 100644 index 0000000..21e229e --- /dev/null +++ b/lib/services/timeout_prediction_service.dart @@ -0,0 +1,224 @@ +import 'dart:convert'; +import 'dart:math'; +import 'package:flutter/foundation.dart'; +import 'package:ml_algo/ml_algo.dart'; +import 'package:ml_dataframe/ml_dataframe.dart'; +import '../models/delivery_observation.dart'; +import 'storage_service.dart'; + +class _ContactStats { + int count = 0; + double _sum = 0; + double _sumSq = 0; + + void add(double ms) { + count++; + _sum += ms; + _sumSq += ms * ms; + } + + double get mean => _sum / count; + double get stdDev => sqrt((_sumSq / count) - (mean * mean)); +} + +class TimeoutPredictionService extends ChangeNotifier { + final StorageService? _storage; + + static const int minObservations = 10; + static const int maxObservations = 100; + static const int _retrainInterval = 5; + static const double _safetyMargin = 1.5; + static const int _minTimeoutMs = 2000; + static const int _maxTimeoutMs = 120000; + static const int _minContactObservations = 10; + + List _observations = []; + LinearRegressor? _model; + List _activeFeatures = []; + int _observationsSinceLastTrain = 0; + final Map _contactStats = {}; + + TimeoutPredictionService(StorageService storage) : _storage = storage; + TimeoutPredictionService.noStorage() : _storage = null; + + int get observationCount => _observations.length; + bool get hasModel => _model != null; + + Future initialize() async { + _observations = await _storage?.loadDeliveryObservations() ?? []; + _rebuildContactStats(); + + if (_observations.length >= minObservations) { + _trainModel(); + } + + debugPrint( + 'TimeoutPrediction: initialized with ${_observations.length} observations, ' + 'model=${_model != null ? "ready" : "waiting for data"}', + ); + } + + void recordObservation({ + required String contactKey, + required int pathLength, + required int messageBytes, + required int tripTimeMs, + int secondsSinceLastRx = 0, + }) { + final observation = DeliveryObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + secondsSinceLastRx: secondsSinceLastRx, + isFlood: pathLength < 0, + deliveryMs: tripTimeMs, + timestamp: DateTime.now(), + ); + + _observations.add(observation); + if (_observations.length > maxObservations) { + _observations.removeAt(0); + } + + _contactStats.putIfAbsent(contactKey, () => _ContactStats()); + _contactStats[contactKey]!.add(tripTimeMs.toDouble()); + + _observationsSinceLastTrain++; + if (_observationsSinceLastTrain >= _retrainInterval && + _observations.length >= minObservations) { + _trainModel(); + } + + _storage?.saveDeliveryObservations(_observations); + debugPrint( + 'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops ' + '(${_observations.length} total)', + ); + } + + int? predictTimeout({ + String? contactKey, + required int pathLength, + required int messageBytes, + int secondsSinceLastRx = 0, + }) { + if (_model == null) return null; + + try { + if (_activeFeatures.isEmpty) return null; + + final allFeatures = { + 'pathLength': pathLength.toDouble(), + 'messageBytes': messageBytes.toDouble(), + 'secSinceRx': secondsSinceLastRx.toDouble(), + 'isFlood': pathLength < 0 ? 1.0 : 0.0, + }; + final row = _activeFeatures.map((f) => allFeatures[f]!).toList(); + + final features = DataFrame( + [row], + headerExists: false, + header: _activeFeatures, + ); + + final prediction = _model!.predict(features); + final rawValue = prediction.rows.first.first; + var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble(); + + debugPrint( + 'TimeoutPrediction: raw prediction=$predictedMs for ' + 'pathLength=$pathLength, messageBytes=$messageBytes, ' + 'features=$_activeFeatures', + ); + + // Sanity check: if prediction is negative or zero, fall back + if (predictedMs <= 0) return null; + + // Blend with per-contact mean if enough data + if (contactKey != null) { + final stats = _contactStats[contactKey]; + if (stats != null && stats.count >= _minContactObservations) { + predictedMs = 0.5 * predictedMs + 0.5 * stats.mean; + } + } + + final timeout = + (predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs); + debugPrint( + 'TimeoutPrediction: ML timeout ${timeout}ms ' + '(raw: ${predictedMs.round()}ms, contact: $contactKey)', + ); + return timeout; + } catch (e) { + debugPrint('TimeoutPrediction: prediction failed: $e'); + return null; + } + } + + void _trainModel() { + try { + // Build feature columns, then exclude any with zero variance + // (ml_algo's OLS produces all-zero coefficients for singular matrices) + final allNames = ['pathLength', 'messageBytes', 'secSinceRx', 'isFlood']; + final allExtractors = [ + (o) => o.pathLength.toDouble(), + (o) => o.messageBytes.toDouble(), + (o) => o.secondsSinceLastRx.toDouble(), + (o) => o.isFlood ? 1.0 : 0.0, + ]; + + _activeFeatures = []; + for (var i = 0; i < allNames.length; i++) { + final values = _observations.map(allExtractors[i]).toSet(); + if (values.length > 1) _activeFeatures.add(allNames[i]); + } + + if (_activeFeatures.isEmpty) { + debugPrint('TimeoutPrediction: no features with variance, skipping training'); + return; + } + + final header = [..._activeFeatures, 'deliveryMs']; + final rows = _observations.map((o) { + final row = []; + for (var i = 0; i < allNames.length; i++) { + if (_activeFeatures.contains(allNames[i])) { + row.add(allExtractors[i](o)); + } + } + row.add(o.deliveryMs.toDouble()); + return row; + }); + + final data = DataFrame( + [header, ...rows], + headerExists: true, + ); + + _model = LinearRegressor(data, 'deliveryMs'); + _observationsSinceLastTrain = 0; + + // Log training summary with sample predictions + final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) / + _observations.length; + debugPrint( + 'TimeoutPrediction: trained on ${_observations.length} observations ' + '(avg: ${avgMs.round()}ms, features: $_activeFeatures)', + ); + + final modelJson = jsonEncode(_model!.toJson()); + _storage?.saveTimeoutModel(modelJson); + notifyListeners(); + } catch (e) { + debugPrint('TimeoutPrediction: training failed: $e'); + } + } + + void _rebuildContactStats() { + _contactStats.clear(); + for (final obs in _observations) { + _contactStats.putIfAbsent(obs.contactKey, () => _ContactStats()); + _contactStats[obs.contactKey]!.add(obs.deliveryMs.toDouble()); + } + } +} diff --git a/pubspec.yaml b/pubspec.yaml index 82e4d9c..4831e67 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -69,6 +69,8 @@ dependencies: material_symbols_icons: ^4.2906.0 web: ^1.1.1 flutter_svg: ^2.0.10+1 + ml_algo: ^16.0.0 + ml_dataframe: ^1.0.0 dev_dependencies: flutter_test: diff --git a/test/services/ml_algo_sanity_test.dart b/test/services/ml_algo_sanity_test.dart new file mode 100644 index 0000000..e4f980e --- /dev/null +++ b/test/services/ml_algo_sanity_test.dart @@ -0,0 +1,122 @@ +import 'package:flutter/foundation.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:ml_algo/ml_algo.dart'; +import 'package:ml_dataframe/ml_dataframe.dart'; + +void main() { + test('LinearRegressor basic sanity check', () { + // Simple: y = 2x + 100 + final data = DataFrame([ + [1.0, 102.0], + [2.0, 104.0], + [3.0, 106.0], + [4.0, 108.0], + [5.0, 110.0], + [10.0, 120.0], + [20.0, 140.0], + [50.0, 200.0], + [0.0, 100.0], + [100.0, 300.0], + ], headerExists: false, header: ['x', 'y']); + + debugPrint('Training data columns: ${data.header}'); + debugPrint('Training data rows: ${data.rows.length}'); + + final model = LinearRegressor(data, 'y'); + + final testDf = DataFrame( + [[25.0]], + headerExists: false, + header: ['x'], + ); + + final prediction = model.predict(testDf); + final value = prediction.rows.first.first; + debugPrint('Predict x=25 → y=$value (expected ~150)'); + expect((value as num).toDouble(), closeTo(150, 5)); + }); + + test('LinearRegressor multi-feature with constant column produces zeros', () { + // isFlood=0 for all rows → zero-variance column → singular matrix + final data = DataFrame([ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 14.0, 0.0, 2200.0], + [2.0, 50.0, 14.0, 0.0, 5000.0], + [4.0, 50.0, 14.0, 0.0, 9500.0], + ], headerExists: false, header: [ + 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', + ]); + + final model = LinearRegressor(data, 'deliveryMs'); + final testDf = DataFrame( + [[2.0, 50.0, 14.0, 0.0]], + headerExists: false, + header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], + ); + final pred = model.predict(testDf).rows.first.first; + debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)'); + }); + + test('LinearRegressor 2-feature works correctly', () { + // Just pathLength + messageBytes → deliveryMs + final data = DataFrame([ + [0.0, 50.0, 1900.0], + [0.0, 80.0, 2200.0], + [2.0, 50.0, 5000.0], + [2.0, 80.0, 5500.0], + [4.0, 50.0, 9500.0], + [4.0, 80.0, 10000.0], + [0.0, 30.0, 1800.0], + [2.0, 30.0, 4800.0], + [4.0, 30.0, 9000.0], + [0.0, 60.0, 2000.0], + ], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']); + + final model = LinearRegressor(data, 'deliveryMs'); + + for (final hops in [0.0, 2.0, 4.0]) { + final testDf = DataFrame( + [[hops, 50.0]], + headerExists: false, + header: ['pathLength', 'messageBytes'], + ); + final pred = model.predict(testDf).rows.first.first; + debugPrint('2-feature: hops=$hops → ${(pred as num).round()}ms'); + } + }); + + test('LinearRegressor multi-feature with variance in all columns', () { + // Mix flood and direct so isFlood has variance + final data = DataFrame([ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 10.0, 0.0, 2200.0], + [2.0, 50.0, 16.0, 0.0, 5000.0], + [2.0, 80.0, 20.0, 0.0, 5500.0], + [4.0, 50.0, 8.0, 0.0, 9500.0], + [4.0, 80.0, 12.0, 0.0, 10000.0], + [-1.0, 40.0, 14.0, 1.0, 5000.0], + [-1.0, 60.0, 18.0, 1.0, 6500.0], + [-1.0, 30.0, 10.0, 1.0, 4000.0], + [-1.0, 80.0, 22.0, 1.0, 7000.0], + ], headerExists: false, header: [ + 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', + ]); + + final model = LinearRegressor(data, 'deliveryMs'); + + for (final tc in [ + [0.0, 50.0, 14.0, 0.0], + [2.0, 50.0, 14.0, 0.0], + [4.0, 50.0, 14.0, 0.0], + [-1.0, 50.0, 14.0, 1.0], + ]) { + final testDf = DataFrame( + [tc], + headerExists: false, + header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], + ); + final pred = model.predict(testDf).rows.first.first; + debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms'); + } + }); +} diff --git a/test/services/timeout_prediction_service_test.dart b/test/services/timeout_prediction_service_test.dart new file mode 100644 index 0000000..46dc5df --- /dev/null +++ b/test/services/timeout_prediction_service_test.dart @@ -0,0 +1,164 @@ +import 'package:flutter/foundation.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:meshcore_open/models/delivery_observation.dart'; +import 'package:meshcore_open/services/timeout_prediction_service.dart'; + +void main() { + late TimeoutPredictionService service; + + setUp(() { + service = TimeoutPredictionService.noStorage(); + }); + + test('trains on sample data and predicts sensible timeouts', () { + // Simulate realistic delivery data: + // Direct 0-hop messages: ~1500-2500ms + // 2-hop messages: ~4000-6000ms + // 4-hop messages: ~8000-12000ms + // Flood messages: ~3000-8000ms + final sampleData = [ + // 0-hop direct + _obs(pathLength: 0, messageBytes: 20, deliveryMs: 1800), + _obs(pathLength: 0, messageBytes: 50, deliveryMs: 2100), + _obs(pathLength: 0, messageBytes: 80, deliveryMs: 2400), + _obs(pathLength: 0, messageBytes: 30, deliveryMs: 1925), + // 2-hop direct + _obs(pathLength: 2, messageBytes: 40, deliveryMs: 4500), + _obs(pathLength: 2, messageBytes: 60, deliveryMs: 5200), + _obs(pathLength: 2, messageBytes: 25, deliveryMs: 4100), + // 4-hop direct + _obs(pathLength: 4, messageBytes: 50, deliveryMs: 9800), + _obs(pathLength: 4, messageBytes: 30, deliveryMs: 8500), + _obs(pathLength: 4, messageBytes: 70, deliveryMs: 10570), + // Flood + _obs(pathLength: -1, messageBytes: 40, deliveryMs: 5000), + _obs(pathLength: -1, messageBytes: 60, deliveryMs: 6500), + ]; + + // Feed all observations + for (final obs in sampleData) { + service.recordObservation( + contactKey: obs.contactKey, + pathLength: obs.pathLength, + messageBytes: obs.messageBytes, + tripTimeMs: obs.deliveryMs, + ); + } + + expect(service.hasModel, isTrue); + expect(service.observationCount, equals(12)); + + // Predict for different scenarios + final direct0 = service.predictTimeout(pathLength: 0, messageBytes: 50); + final direct2 = service.predictTimeout(pathLength: 2, messageBytes: 50); + final direct4 = service.predictTimeout(pathLength: 4, messageBytes: 50); + final flood = service.predictTimeout(pathLength: -1, messageBytes: 50); + + // All should return non-null (model is trained) + expect(direct0, isNotNull); + expect(direct2, isNotNull); + expect(direct4, isNotNull); + expect(flood, isNotNull); + + // More hops should predict longer timeouts + expect(direct4!, greaterThan(direct2!)); + expect(direct2, greaterThan(direct0!)); + + // All should be within the clamp range + expect(direct0, greaterThanOrEqualTo(2000)); + expect(direct4, lessThanOrEqualTo(120000)); + + // Print predictions for visibility + debugPrint('Predictions (with 1.5x safety margin):'); + debugPrint(' 0-hop direct: ${direct0}ms'); + debugPrint(' 2-hop direct: ${direct2}ms'); + debugPrint(' 4-hop direct: ${direct4}ms'); + debugPrint(' flood: ${flood}ms'); + }); + + test('returns null before minimum observations', () { + for (var i = 0; i < TimeoutPredictionService.minObservations - 1; i++) { + service.recordObservation( + contactKey: 'abc', + pathLength: 0, + messageBytes: 50, + tripTimeMs: 2000, + ); + } + + expect(service.hasModel, isFalse); + expect(service.predictTimeout(pathLength: 0, messageBytes: 50), isNull); + }); + + test('caps observations at maxObservations', () { + for (var i = 0; i < TimeoutPredictionService.maxObservations + 20; i++) { + service.recordObservation( + contactKey: 'abc', + pathLength: 0, + messageBytes: 50, + tripTimeMs: 2000 + i, + ); + } + + expect( + service.observationCount, + equals(TimeoutPredictionService.maxObservations), + ); + }); + + test('blends per-contact stats after enough observations', () { + // Train with mixed contacts and varied features: + // contactA is fast (0-hop), contactB is slow (2-hop) + for (var i = 0; i < 12; i++) { + service.recordObservation( + contactKey: 'contactA', + pathLength: 0, + messageBytes: 30 + i, + tripTimeMs: 1500, + ); + service.recordObservation( + contactKey: 'contactB', + pathLength: 2, + messageBytes: 30 + i, + tripTimeMs: 8000, + ); + } + + final predA = service.predictTimeout( + contactKey: 'contactA', + pathLength: 0, + messageBytes: 50, + ); + final predB = service.predictTimeout( + contactKey: 'contactB', + pathLength: 0, + messageBytes: 50, + ); + + expect(predA, isNotNull); + expect(predB, isNotNull); + // Contact B (slow) should have a higher predicted timeout than A (fast) + expect(predB!, greaterThan(predA!)); + + debugPrint('Per-contact blending:'); + debugPrint(' contactA (fast): ${predA}ms'); + debugPrint(' contactB (slow): ${predB}ms'); + }); +} + +DeliveryObservation _obs({ + required int pathLength, + required int messageBytes, + required int deliveryMs, + String contactKey = 'test_contact', +}) { + return DeliveryObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + secondsSinceLastRx: 5, + isFlood: pathLength < 0, + deliveryMs: deliveryMs, + timestamp: DateTime.now(), + ); +}