feat: add ML-based adaptive timeout prediction using LinearRegressor

Train a linear regression model on actual message delivery times to
predict tighter timeouts, replacing worst-case physics estimates.
Features: path length, message bytes, seconds since last RX, flood mode.
Global model with per-contact blending after 10+ observations per contact.
Falls back to existing physics formula when model has insufficient data.
This commit is contained in:
zjs81 2026-03-14 16:56:11 -07:00
parent 8b280b37be
commit 2ee2358ecc
9 changed files with 683 additions and 20 deletions

View file

@ -0,0 +1,122 @@
import 'package:flutter/foundation.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
void main() {
test('LinearRegressor basic sanity check', () {
// Simple: y = 2x + 100
final data = DataFrame([
[1.0, 102.0],
[2.0, 104.0],
[3.0, 106.0],
[4.0, 108.0],
[5.0, 110.0],
[10.0, 120.0],
[20.0, 140.0],
[50.0, 200.0],
[0.0, 100.0],
[100.0, 300.0],
], headerExists: false, header: ['x', 'y']);
debugPrint('Training data columns: ${data.header}');
debugPrint('Training data rows: ${data.rows.length}');
final model = LinearRegressor(data, 'y');
final testDf = DataFrame(
[[25.0]],
headerExists: false,
header: ['x'],
);
final prediction = model.predict(testDf);
final value = prediction.rows.first.first;
debugPrint('Predict x=25 → y=$value (expected ~150)');
expect((value as num).toDouble(), closeTo(150, 5));
});
test('LinearRegressor multi-feature with constant column produces zeros', () {
// isFlood=0 for all rows zero-variance column singular matrix
final data = DataFrame([
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 14.0, 0.0, 2200.0],
[2.0, 50.0, 14.0, 0.0, 5000.0],
[4.0, 50.0, 14.0, 0.0, 9500.0],
], headerExists: false, header: [
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
]);
final model = LinearRegressor(data, 'deliveryMs');
final testDf = DataFrame(
[[2.0, 50.0, 14.0, 0.0]],
headerExists: false,
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)');
});
test('LinearRegressor 2-feature works correctly', () {
// Just pathLength + messageBytes deliveryMs
final data = DataFrame([
[0.0, 50.0, 1900.0],
[0.0, 80.0, 2200.0],
[2.0, 50.0, 5000.0],
[2.0, 80.0, 5500.0],
[4.0, 50.0, 9500.0],
[4.0, 80.0, 10000.0],
[0.0, 30.0, 1800.0],
[2.0, 30.0, 4800.0],
[4.0, 30.0, 9000.0],
[0.0, 60.0, 2000.0],
], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']);
final model = LinearRegressor(data, 'deliveryMs');
for (final hops in [0.0, 2.0, 4.0]) {
final testDf = DataFrame(
[[hops, 50.0]],
headerExists: false,
header: ['pathLength', 'messageBytes'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint('2-feature: hops=$hops${(pred as num).round()}ms');
}
});
test('LinearRegressor multi-feature with variance in all columns', () {
// Mix flood and direct so isFlood has variance
final data = DataFrame([
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 10.0, 0.0, 2200.0],
[2.0, 50.0, 16.0, 0.0, 5000.0],
[2.0, 80.0, 20.0, 0.0, 5500.0],
[4.0, 50.0, 8.0, 0.0, 9500.0],
[4.0, 80.0, 12.0, 0.0, 10000.0],
[-1.0, 40.0, 14.0, 1.0, 5000.0],
[-1.0, 60.0, 18.0, 1.0, 6500.0],
[-1.0, 30.0, 10.0, 1.0, 4000.0],
[-1.0, 80.0, 22.0, 1.0, 7000.0],
], headerExists: false, header: [
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
]);
final model = LinearRegressor(data, 'deliveryMs');
for (final tc in [
[0.0, 50.0, 14.0, 0.0],
[2.0, 50.0, 14.0, 0.0],
[4.0, 50.0, 14.0, 0.0],
[-1.0, 50.0, 14.0, 1.0],
]) {
final testDf = DataFrame(
[tc],
headerExists: false,
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]}${(pred as num).round()}ms');
}
});
}

View file

@ -0,0 +1,164 @@
import 'package:flutter/foundation.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:meshcore_open/models/delivery_observation.dart';
import 'package:meshcore_open/services/timeout_prediction_service.dart';
void main() {
late TimeoutPredictionService service;
setUp(() {
service = TimeoutPredictionService.noStorage();
});
test('trains on sample data and predicts sensible timeouts', () {
// Simulate realistic delivery data:
// Direct 0-hop messages: ~1500-2500ms
// 2-hop messages: ~4000-6000ms
// 4-hop messages: ~8000-12000ms
// Flood messages: ~3000-8000ms
final sampleData = [
// 0-hop direct
_obs(pathLength: 0, messageBytes: 20, deliveryMs: 1800),
_obs(pathLength: 0, messageBytes: 50, deliveryMs: 2100),
_obs(pathLength: 0, messageBytes: 80, deliveryMs: 2400),
_obs(pathLength: 0, messageBytes: 30, deliveryMs: 1925),
// 2-hop direct
_obs(pathLength: 2, messageBytes: 40, deliveryMs: 4500),
_obs(pathLength: 2, messageBytes: 60, deliveryMs: 5200),
_obs(pathLength: 2, messageBytes: 25, deliveryMs: 4100),
// 4-hop direct
_obs(pathLength: 4, messageBytes: 50, deliveryMs: 9800),
_obs(pathLength: 4, messageBytes: 30, deliveryMs: 8500),
_obs(pathLength: 4, messageBytes: 70, deliveryMs: 10570),
// Flood
_obs(pathLength: -1, messageBytes: 40, deliveryMs: 5000),
_obs(pathLength: -1, messageBytes: 60, deliveryMs: 6500),
];
// Feed all observations
for (final obs in sampleData) {
service.recordObservation(
contactKey: obs.contactKey,
pathLength: obs.pathLength,
messageBytes: obs.messageBytes,
tripTimeMs: obs.deliveryMs,
);
}
expect(service.hasModel, isTrue);
expect(service.observationCount, equals(12));
// Predict for different scenarios
final direct0 = service.predictTimeout(pathLength: 0, messageBytes: 50);
final direct2 = service.predictTimeout(pathLength: 2, messageBytes: 50);
final direct4 = service.predictTimeout(pathLength: 4, messageBytes: 50);
final flood = service.predictTimeout(pathLength: -1, messageBytes: 50);
// All should return non-null (model is trained)
expect(direct0, isNotNull);
expect(direct2, isNotNull);
expect(direct4, isNotNull);
expect(flood, isNotNull);
// More hops should predict longer timeouts
expect(direct4!, greaterThan(direct2!));
expect(direct2, greaterThan(direct0!));
// All should be within the clamp range
expect(direct0, greaterThanOrEqualTo(2000));
expect(direct4, lessThanOrEqualTo(120000));
// Print predictions for visibility
debugPrint('Predictions (with 1.5x safety margin):');
debugPrint(' 0-hop direct: ${direct0}ms');
debugPrint(' 2-hop direct: ${direct2}ms');
debugPrint(' 4-hop direct: ${direct4}ms');
debugPrint(' flood: ${flood}ms');
});
test('returns null before minimum observations', () {
for (var i = 0; i < TimeoutPredictionService.minObservations - 1; i++) {
service.recordObservation(
contactKey: 'abc',
pathLength: 0,
messageBytes: 50,
tripTimeMs: 2000,
);
}
expect(service.hasModel, isFalse);
expect(service.predictTimeout(pathLength: 0, messageBytes: 50), isNull);
});
test('caps observations at maxObservations', () {
for (var i = 0; i < TimeoutPredictionService.maxObservations + 20; i++) {
service.recordObservation(
contactKey: 'abc',
pathLength: 0,
messageBytes: 50,
tripTimeMs: 2000 + i,
);
}
expect(
service.observationCount,
equals(TimeoutPredictionService.maxObservations),
);
});
test('blends per-contact stats after enough observations', () {
// Train with mixed contacts and varied features:
// contactA is fast (0-hop), contactB is slow (2-hop)
for (var i = 0; i < 12; i++) {
service.recordObservation(
contactKey: 'contactA',
pathLength: 0,
messageBytes: 30 + i,
tripTimeMs: 1500,
);
service.recordObservation(
contactKey: 'contactB',
pathLength: 2,
messageBytes: 30 + i,
tripTimeMs: 8000,
);
}
final predA = service.predictTimeout(
contactKey: 'contactA',
pathLength: 0,
messageBytes: 50,
);
final predB = service.predictTimeout(
contactKey: 'contactB',
pathLength: 0,
messageBytes: 50,
);
expect(predA, isNotNull);
expect(predB, isNotNull);
// Contact B (slow) should have a higher predicted timeout than A (fast)
expect(predB!, greaterThan(predA!));
debugPrint('Per-contact blending:');
debugPrint(' contactA (fast): ${predA}ms');
debugPrint(' contactB (slow): ${predB}ms');
});
}
DeliveryObservation _obs({
required int pathLength,
required int messageBytes,
required int deliveryMs,
String contactKey = 'test_contact',
}) {
return DeliveryObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
secondsSinceLastRx: 5,
isFlood: pathLength < 0,
deliveryMs: deliveryMs,
timestamp: DateTime.now(),
);
}