mirror of
https://github.com/zjs81/meshcore-open.git
synced 2026-04-20 22:13:48 +00:00
feat: add ML-based adaptive timeout prediction using LinearRegressor
Train a linear regression model on actual message delivery times to predict tighter timeouts, replacing worst-case physics estimates. Features: path length, message bytes, seconds since last RX, flood mode. Global model with per-contact blending after 10+ observations per contact. Falls back to existing physics formula when model has insufficient data.
This commit is contained in:
parent
8b280b37be
commit
2ee2358ecc
9 changed files with 683 additions and 20 deletions
122
test/services/ml_algo_sanity_test.dart
Normal file
122
test/services/ml_algo_sanity_test.dart
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import 'package:flutter/foundation.dart';
|
||||
import 'package:flutter_test/flutter_test.dart';
|
||||
import 'package:ml_algo/ml_algo.dart';
|
||||
import 'package:ml_dataframe/ml_dataframe.dart';
|
||||
|
||||
void main() {
|
||||
test('LinearRegressor basic sanity check', () {
|
||||
// Simple: y = 2x + 100
|
||||
final data = DataFrame([
|
||||
[1.0, 102.0],
|
||||
[2.0, 104.0],
|
||||
[3.0, 106.0],
|
||||
[4.0, 108.0],
|
||||
[5.0, 110.0],
|
||||
[10.0, 120.0],
|
||||
[20.0, 140.0],
|
||||
[50.0, 200.0],
|
||||
[0.0, 100.0],
|
||||
[100.0, 300.0],
|
||||
], headerExists: false, header: ['x', 'y']);
|
||||
|
||||
debugPrint('Training data columns: ${data.header}');
|
||||
debugPrint('Training data rows: ${data.rows.length}');
|
||||
|
||||
final model = LinearRegressor(data, 'y');
|
||||
|
||||
final testDf = DataFrame(
|
||||
[[25.0]],
|
||||
headerExists: false,
|
||||
header: ['x'],
|
||||
);
|
||||
|
||||
final prediction = model.predict(testDf);
|
||||
final value = prediction.rows.first.first;
|
||||
debugPrint('Predict x=25 → y=$value (expected ~150)');
|
||||
expect((value as num).toDouble(), closeTo(150, 5));
|
||||
});
|
||||
|
||||
test('LinearRegressor multi-feature with constant column produces zeros', () {
|
||||
// isFlood=0 for all rows → zero-variance column → singular matrix
|
||||
final data = DataFrame([
|
||||
[0.0, 50.0, 14.0, 0.0, 1900.0],
|
||||
[0.0, 80.0, 14.0, 0.0, 2200.0],
|
||||
[2.0, 50.0, 14.0, 0.0, 5000.0],
|
||||
[4.0, 50.0, 14.0, 0.0, 9500.0],
|
||||
], headerExists: false, header: [
|
||||
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
|
||||
]);
|
||||
|
||||
final model = LinearRegressor(data, 'deliveryMs');
|
||||
final testDf = DataFrame(
|
||||
[[2.0, 50.0, 14.0, 0.0]],
|
||||
headerExists: false,
|
||||
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
|
||||
);
|
||||
final pred = model.predict(testDf).rows.first.first;
|
||||
debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)');
|
||||
});
|
||||
|
||||
test('LinearRegressor 2-feature works correctly', () {
|
||||
// Just pathLength + messageBytes → deliveryMs
|
||||
final data = DataFrame([
|
||||
[0.0, 50.0, 1900.0],
|
||||
[0.0, 80.0, 2200.0],
|
||||
[2.0, 50.0, 5000.0],
|
||||
[2.0, 80.0, 5500.0],
|
||||
[4.0, 50.0, 9500.0],
|
||||
[4.0, 80.0, 10000.0],
|
||||
[0.0, 30.0, 1800.0],
|
||||
[2.0, 30.0, 4800.0],
|
||||
[4.0, 30.0, 9000.0],
|
||||
[0.0, 60.0, 2000.0],
|
||||
], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']);
|
||||
|
||||
final model = LinearRegressor(data, 'deliveryMs');
|
||||
|
||||
for (final hops in [0.0, 2.0, 4.0]) {
|
||||
final testDf = DataFrame(
|
||||
[[hops, 50.0]],
|
||||
headerExists: false,
|
||||
header: ['pathLength', 'messageBytes'],
|
||||
);
|
||||
final pred = model.predict(testDf).rows.first.first;
|
||||
debugPrint('2-feature: hops=$hops → ${(pred as num).round()}ms');
|
||||
}
|
||||
});
|
||||
|
||||
test('LinearRegressor multi-feature with variance in all columns', () {
|
||||
// Mix flood and direct so isFlood has variance
|
||||
final data = DataFrame([
|
||||
[0.0, 50.0, 14.0, 0.0, 1900.0],
|
||||
[0.0, 80.0, 10.0, 0.0, 2200.0],
|
||||
[2.0, 50.0, 16.0, 0.0, 5000.0],
|
||||
[2.0, 80.0, 20.0, 0.0, 5500.0],
|
||||
[4.0, 50.0, 8.0, 0.0, 9500.0],
|
||||
[4.0, 80.0, 12.0, 0.0, 10000.0],
|
||||
[-1.0, 40.0, 14.0, 1.0, 5000.0],
|
||||
[-1.0, 60.0, 18.0, 1.0, 6500.0],
|
||||
[-1.0, 30.0, 10.0, 1.0, 4000.0],
|
||||
[-1.0, 80.0, 22.0, 1.0, 7000.0],
|
||||
], headerExists: false, header: [
|
||||
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
|
||||
]);
|
||||
|
||||
final model = LinearRegressor(data, 'deliveryMs');
|
||||
|
||||
for (final tc in [
|
||||
[0.0, 50.0, 14.0, 0.0],
|
||||
[2.0, 50.0, 14.0, 0.0],
|
||||
[4.0, 50.0, 14.0, 0.0],
|
||||
[-1.0, 50.0, 14.0, 1.0],
|
||||
]) {
|
||||
final testDf = DataFrame(
|
||||
[tc],
|
||||
headerExists: false,
|
||||
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
|
||||
);
|
||||
final pred = model.predict(testDf).rows.first.first;
|
||||
debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms');
|
||||
}
|
||||
});
|
||||
}
|
||||
164
test/services/timeout_prediction_service_test.dart
Normal file
164
test/services/timeout_prediction_service_test.dart
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import 'package:flutter/foundation.dart';
|
||||
import 'package:flutter_test/flutter_test.dart';
|
||||
import 'package:meshcore_open/models/delivery_observation.dart';
|
||||
import 'package:meshcore_open/services/timeout_prediction_service.dart';
|
||||
|
||||
void main() {
|
||||
late TimeoutPredictionService service;
|
||||
|
||||
setUp(() {
|
||||
service = TimeoutPredictionService.noStorage();
|
||||
});
|
||||
|
||||
test('trains on sample data and predicts sensible timeouts', () {
|
||||
// Simulate realistic delivery data:
|
||||
// Direct 0-hop messages: ~1500-2500ms
|
||||
// 2-hop messages: ~4000-6000ms
|
||||
// 4-hop messages: ~8000-12000ms
|
||||
// Flood messages: ~3000-8000ms
|
||||
final sampleData = [
|
||||
// 0-hop direct
|
||||
_obs(pathLength: 0, messageBytes: 20, deliveryMs: 1800),
|
||||
_obs(pathLength: 0, messageBytes: 50, deliveryMs: 2100),
|
||||
_obs(pathLength: 0, messageBytes: 80, deliveryMs: 2400),
|
||||
_obs(pathLength: 0, messageBytes: 30, deliveryMs: 1925),
|
||||
// 2-hop direct
|
||||
_obs(pathLength: 2, messageBytes: 40, deliveryMs: 4500),
|
||||
_obs(pathLength: 2, messageBytes: 60, deliveryMs: 5200),
|
||||
_obs(pathLength: 2, messageBytes: 25, deliveryMs: 4100),
|
||||
// 4-hop direct
|
||||
_obs(pathLength: 4, messageBytes: 50, deliveryMs: 9800),
|
||||
_obs(pathLength: 4, messageBytes: 30, deliveryMs: 8500),
|
||||
_obs(pathLength: 4, messageBytes: 70, deliveryMs: 10570),
|
||||
// Flood
|
||||
_obs(pathLength: -1, messageBytes: 40, deliveryMs: 5000),
|
||||
_obs(pathLength: -1, messageBytes: 60, deliveryMs: 6500),
|
||||
];
|
||||
|
||||
// Feed all observations
|
||||
for (final obs in sampleData) {
|
||||
service.recordObservation(
|
||||
contactKey: obs.contactKey,
|
||||
pathLength: obs.pathLength,
|
||||
messageBytes: obs.messageBytes,
|
||||
tripTimeMs: obs.deliveryMs,
|
||||
);
|
||||
}
|
||||
|
||||
expect(service.hasModel, isTrue);
|
||||
expect(service.observationCount, equals(12));
|
||||
|
||||
// Predict for different scenarios
|
||||
final direct0 = service.predictTimeout(pathLength: 0, messageBytes: 50);
|
||||
final direct2 = service.predictTimeout(pathLength: 2, messageBytes: 50);
|
||||
final direct4 = service.predictTimeout(pathLength: 4, messageBytes: 50);
|
||||
final flood = service.predictTimeout(pathLength: -1, messageBytes: 50);
|
||||
|
||||
// All should return non-null (model is trained)
|
||||
expect(direct0, isNotNull);
|
||||
expect(direct2, isNotNull);
|
||||
expect(direct4, isNotNull);
|
||||
expect(flood, isNotNull);
|
||||
|
||||
// More hops should predict longer timeouts
|
||||
expect(direct4!, greaterThan(direct2!));
|
||||
expect(direct2, greaterThan(direct0!));
|
||||
|
||||
// All should be within the clamp range
|
||||
expect(direct0, greaterThanOrEqualTo(2000));
|
||||
expect(direct4, lessThanOrEqualTo(120000));
|
||||
|
||||
// Print predictions for visibility
|
||||
debugPrint('Predictions (with 1.5x safety margin):');
|
||||
debugPrint(' 0-hop direct: ${direct0}ms');
|
||||
debugPrint(' 2-hop direct: ${direct2}ms');
|
||||
debugPrint(' 4-hop direct: ${direct4}ms');
|
||||
debugPrint(' flood: ${flood}ms');
|
||||
});
|
||||
|
||||
test('returns null before minimum observations', () {
|
||||
for (var i = 0; i < TimeoutPredictionService.minObservations - 1; i++) {
|
||||
service.recordObservation(
|
||||
contactKey: 'abc',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
tripTimeMs: 2000,
|
||||
);
|
||||
}
|
||||
|
||||
expect(service.hasModel, isFalse);
|
||||
expect(service.predictTimeout(pathLength: 0, messageBytes: 50), isNull);
|
||||
});
|
||||
|
||||
test('caps observations at maxObservations', () {
|
||||
for (var i = 0; i < TimeoutPredictionService.maxObservations + 20; i++) {
|
||||
service.recordObservation(
|
||||
contactKey: 'abc',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
tripTimeMs: 2000 + i,
|
||||
);
|
||||
}
|
||||
|
||||
expect(
|
||||
service.observationCount,
|
||||
equals(TimeoutPredictionService.maxObservations),
|
||||
);
|
||||
});
|
||||
|
||||
test('blends per-contact stats after enough observations', () {
|
||||
// Train with mixed contacts and varied features:
|
||||
// contactA is fast (0-hop), contactB is slow (2-hop)
|
||||
for (var i = 0; i < 12; i++) {
|
||||
service.recordObservation(
|
||||
contactKey: 'contactA',
|
||||
pathLength: 0,
|
||||
messageBytes: 30 + i,
|
||||
tripTimeMs: 1500,
|
||||
);
|
||||
service.recordObservation(
|
||||
contactKey: 'contactB',
|
||||
pathLength: 2,
|
||||
messageBytes: 30 + i,
|
||||
tripTimeMs: 8000,
|
||||
);
|
||||
}
|
||||
|
||||
final predA = service.predictTimeout(
|
||||
contactKey: 'contactA',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
);
|
||||
final predB = service.predictTimeout(
|
||||
contactKey: 'contactB',
|
||||
pathLength: 0,
|
||||
messageBytes: 50,
|
||||
);
|
||||
|
||||
expect(predA, isNotNull);
|
||||
expect(predB, isNotNull);
|
||||
// Contact B (slow) should have a higher predicted timeout than A (fast)
|
||||
expect(predB!, greaterThan(predA!));
|
||||
|
||||
debugPrint('Per-contact blending:');
|
||||
debugPrint(' contactA (fast): ${predA}ms');
|
||||
debugPrint(' contactB (slow): ${predB}ms');
|
||||
});
|
||||
}
|
||||
|
||||
DeliveryObservation _obs({
|
||||
required int pathLength,
|
||||
required int messageBytes,
|
||||
required int deliveryMs,
|
||||
String contactKey = 'test_contact',
|
||||
}) {
|
||||
return DeliveryObservation(
|
||||
contactKey: contactKey,
|
||||
pathLength: pathLength,
|
||||
messageBytes: messageBytes,
|
||||
secondsSinceLastRx: 5,
|
||||
isFlood: pathLength < 0,
|
||||
deliveryMs: deliveryMs,
|
||||
timestamp: DateTime.now(),
|
||||
);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue