mirror of
https://github.com/yellowcooln/meshcore-mqtt-live-map.git
synced 2026-04-20 23:23:36 +00:00
105 lines
2.8 KiB
Python
105 lines
2.8 KiB
Python
import asyncio
|
|
import json
|
|
import time
|
|
|
|
from starlette.websockets import WebSocketDisconnect
|
|
|
|
import app
|
|
|
|
|
|
class DummyWebSocket:
    """In-memory stand-in for a WebSocket connection, used by these tests.

    The object records whether it was accepted, the code it was closed
    with (if any), and every text frame sent through it. Any attempt to
    read from it raises ``WebSocketDisconnect``.
    """

    def __init__(self, query_params=None, headers=None):
        # A falsy argument (None or an empty mapping) becomes a fresh dict.
        self.query_params = query_params if query_params else {}
        self.headers = headers if headers else {}
        # Nothing has happened on the connection yet.
        self.accepted = False
        self.closed_code = None
        self.sent_messages = []

    async def accept(self):
        """Record that the connection was accepted."""
        self.accepted = True

    async def close(self, code=1000):
        """Remember the close code; no real connection is closed."""
        self.closed_code = code

    async def send_text(self, text):
        """Capture an outgoing text frame in ``sent_messages``."""
        self.sent_messages.append(text)

    async def receive_text(self):
        """Always raise ``WebSocketDisconnect`` instead of returning data."""
        raise WebSocketDisconnect()
|
|
|
|
|
|
def test_ws_endpoint_sends_snapshot_payload(monkeypatch):
    """With PROD_MODE off, the first frame sent must be a snapshot.

    The fake socket disconnects on its first read, so the endpoint is
    expected to return on its own and leave no entry in ``app.clients``.
    """
    app.clients.clear()
    monkeypatch.setattr(app, "PROD_MODE", False)
    socket = DummyWebSocket()

    asyncio.run(app.ws_endpoint(socket))

    # Connection was accepted, never closed by the server, and got data.
    assert socket.accepted is True
    assert socket.closed_code is None
    assert len(socket.sent_messages) >= 1

    first_frame = json.loads(socket.sent_messages[0])
    assert first_frame["type"] == "snapshot"
    for required_key in ("devices", "routes"):
        assert required_key in first_frame

    # The endpoint must have unregistered the client after the disconnect.
    assert socket not in app.clients
|
|
|
|
|
|
def test_ws_endpoint_rejects_unauthorized_in_prod_mode(monkeypatch):
    """In PROD_MODE a socket without a token must be closed with code 1008.

    Turnstile is switched off so that only the PROD_TOKEN check is in play.
    """
    socket = DummyWebSocket()
    app.clients.clear()

    overrides = (
        ("TURNSTILE_ENABLED", False),
        ("turnstile_verifier", None),
        ("PROD_MODE", True),
        ("PROD_TOKEN", "secret-token"),
    )
    for attr_name, attr_value in overrides:
        monkeypatch.setattr(app, attr_name, attr_value)

    asyncio.run(app.ws_endpoint(socket))

    # Accepted first, then closed with 1008, without sending any payload.
    assert socket.accepted is True
    assert socket.closed_code == 1008
    assert socket.sent_messages == []
    assert socket not in app.clients
|
|
|
|
|
|
def test_ws_endpoint_accepts_query_token_in_prod_mode(monkeypatch):
    """In PROD_MODE a matching ``token`` query parameter must be accepted.

    Turnstile is switched off so that only the PROD_TOKEN check is in play.
    """
    expected_token = "secret-token"
    socket = DummyWebSocket(query_params={"token": expected_token})
    app.clients.clear()

    monkeypatch.setattr(app, "TURNSTILE_ENABLED", False)
    monkeypatch.setattr(app, "turnstile_verifier", None)
    monkeypatch.setattr(app, "PROD_MODE", True)
    monkeypatch.setattr(app, "PROD_TOKEN", expected_token)

    asyncio.run(app.ws_endpoint(socket))

    # Authorised client: accepted, not closed by the server, and sent data.
    assert socket.accepted is True
    assert socket.closed_code is None
    assert len(socket.sent_messages) >= 1
|
|
|
|
|
|
def test_ws_snapshot_omits_near_expired_routes_and_includes_server_time(monkeypatch):
    """Snapshot must skip a route about to expire and report the server time.

    The clock seen through ``app.time.time`` is pinned to ``now``. Two routes
    are registered: "keep" expires 30 seconds after ``now`` and "drop" only
    2 seconds after. The test expects "keep" to be present in the snapshot,
    "drop" to be absent, and ``server_time`` to equal the pinned clock value.
    The exact cut-off the endpoint applies is not visible from this test.
    """
    ws = DummyWebSocket()
    app.clients.clear()
    app.routes.clear()
    monkeypatch.setattr(app, "PROD_MODE", False)
    now = 1000.0
    # Freeze the clock so the expiry comparison and server_time are exact.
    monkeypatch.setattr(app.time, "time", lambda: now)

    try:
        app.routes["keep"] = {
            "id": "keep",
            "points": [[1.0, 2.0], [3.0, 4.0]],
            "expires_at": now + 30.0,
        }
        app.routes["drop"] = {
            "id": "drop",
            "points": [[1.0, 2.0], [3.0, 4.0]],
            "expires_at": now + 2.0,
        }

        asyncio.run(app.ws_endpoint(ws))

        payload = json.loads(ws.sent_messages[0])
        route_ids = [route.get("id") for route in payload.get("routes", [])]
        assert "keep" in route_ids
        assert "drop" not in route_ids
        assert payload.get("server_time") == now
    finally:
        # Always remove the fixture routes, even when the endpoint raises or
        # an assertion fails, so they cannot leak into later tests through
        # the shared module-level dict.
        app.routes.clear()
|