Fix race condition on pendingRpcs adding/pulling

This commit is contained in:
Wizou 2022-03-27 22:29:48 +02:00
parent a54cc92618
commit 5c5b8032b9
2 changed files with 73 additions and 83 deletions

View file

@ -65,8 +65,8 @@ namespace WTelegram
private readonly Random _random = new();
private int _saltChangeCounter;
private Task _reactorTask;
private long _bareRequest;
private readonly Dictionary<long, (Type type, TaskCompletionSource<object> tcs)> _pendingRequests = new();
private Rpc _bareRpc;
private readonly Dictionary<long, Rpc> _pendingRpcs = new();
private SemaphoreSlim _sendSemaphore = new(0);
private readonly SemaphoreSlim _semaphore = new(1);
private Task _connecting;
@ -501,20 +501,20 @@ namespace WTelegram
lock (_msgsToAck) _msgsToAck.Clear();
Reset(false, false);
_reactorReconnects = (_reactorReconnects + 1) % MaxAutoReconnects;
if (!IsMainDC && _pendingRequests.Count <= 1 && ex is ApplicationException { Message: ConnectionShutDown } or IOException { InnerException: SocketException })
if (_pendingRequests.Values.FirstOrDefault() is var (type, tcs) && (type is null || type == typeof(Pong)))
if (!IsMainDC && _pendingRpcs.Count <= 1 && ex is ApplicationException { Message: ConnectionShutDown } or IOException { InnerException: SocketException })
if (_pendingRpcs.Values.FirstOrDefault() is not Rpc rpc || rpc.type == typeof(Pong))
_reactorReconnects = 0;
if (_reactorReconnects != 0)
{
await Task.Delay(5000);
if (_networkStream == null) return; // Dispose has been called in-between
await ConnectAsync(); // start a new reactor after 5 secs
lock (_pendingRequests) // retry all pending requests
lock (_pendingRpcs) // retry all pending requests
{
foreach (var (_, tcs) in _pendingRequests.Values)
tcs.SetResult(reactorError);
_pendingRequests.Clear();
_bareRequest = 0;
foreach (var rpc in _pendingRpcs.Values)
rpc.tcs.SetResult(reactorError);
_pendingRpcs.Clear();
_bareRpc = null;
}
// TODO: implement an Updates gaps handling system? https://core.telegram.org/api/updates
if (IsMainDC)
@ -528,12 +528,12 @@ namespace WTelegram
}
catch
{
lock (_pendingRequests) // abort all pending requests
lock (_pendingRpcs) // abort all pending requests
{
foreach (var (_, tcs) in _pendingRequests.Values)
tcs.SetException(ex);
_pendingRequests.Clear();
_bareRequest = 0;
foreach (var rpc in _pendingRpcs.Values)
rpc.tcs.SetException(ex);
_pendingRpcs.Clear();
_bareRpc = null;
}
OnUpdate(reactorError);
}
@ -639,16 +639,20 @@ namespace WTelegram
};
}
private async Task<long> SendAsync(IObject msg, bool isContent)
private async Task SendAsync(IObject msg, bool isContent, Rpc rpc = null)
{
if (_dcSession.AuthKeyID != 0 && isContent && CheckMsgsToAck() is MsgsAck msgsAck)
isContent &= _dcSession.AuthKeyID != 0;
(long msgId, int seqno) = NewMsgId(isContent);
if (rpc != null)
lock (_pendingRpcs)
_pendingRpcs[rpc.msgId = msgId] = rpc;
if (isContent && CheckMsgsToAck() is MsgsAck msgsAck)
{
var ackMsg = NewMsgId(false);
var mainMsg = NewMsgId(true);
await SendAsync(MakeContainer((msgsAck, ackMsg), (msg, mainMsg)), false);
return mainMsg.msgId;
var (ackId, ackSeqno) = NewMsgId(false);
var container = new MsgContainer { messages = new _Message[] { new(msgId, seqno, msg), new(ackId, ackSeqno, msgsAck) } };
await SendAsync(container, false);
return;
}
(long msgId, int seqno) = NewMsgId(isContent && _dcSession.AuthKeyID != 0);
await _sendSemaphore.WaitAsync();
try
{
@ -714,7 +718,6 @@ namespace WTelegram
{
_sendSemaphore.Release();
}
return msgId;
}
internal MsgContainer ReadMsgContainer(TL.BinaryReader reader)
@ -723,12 +726,7 @@ namespace WTelegram
var array = new _Message[count];
for (int i = 0; i < count; i++)
{
var msg = array[i] = new _Message
{
msg_id = reader.ReadInt64(),
seqno = reader.ReadInt32(),
bytes = reader.ReadInt32(),
};
var msg = array[i] = new _Message(reader.ReadInt64(), reader.ReadInt32(), null) { bytes = reader.ReadInt32() };
if ((msg.seqno & 1) != 0) lock (_msgsToAck) _msgsToAck.Add(msg.msg_id);
var pos = reader.BaseStream.Position;
try
@ -757,14 +755,14 @@ namespace WTelegram
private RpcResult ReadRpcResult(TL.BinaryReader reader)
{
long msgId = reader.ReadInt64();
var (type, tcs) = PullPendingRequest(msgId);
var rpc = PullPendingRequest(msgId);
object result;
if (tcs != null)
if (rpc != null)
{
try
{
if (!type.IsArray)
result = reader.ReadTLValue(type);
if (!rpc.type.IsArray)
result = reader.ReadTLValue(rpc.type);
else
{
var peek = reader.ReadUInt32();
@ -772,23 +770,23 @@ namespace WTelegram
result = reader.ReadTLObject(Layer.RpcErrorCtor);
else if (peek == Layer.GZipedCtor)
using (var gzipReader = new TL.BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress), reader.Client))
result = gzipReader.ReadTLValue(type);
result = gzipReader.ReadTLValue(rpc.type);
else
{
reader.BaseStream.Position -= 4;
result = reader.ReadTLValue(type);
result = reader.ReadTLValue(rpc.type);
}
}
if (type.IsEnum) result = Enum.ToObject(type, result);
if (rpc.type.IsEnum) result = Enum.ToObject(rpc.type, result);
if (result is RpcError rpcError)
Helpers.Log(4, $" → RpcError {rpcError.error_code,3} {rpcError.error_message,-24} #{(short)msgId.GetHashCode():X4}");
else
Helpers.Log(1, $" → {result?.GetType().Name,-37} #{(short)msgId.GetHashCode():X4}");
tcs.SetResult(result);
rpc.tcs.SetResult(result);
}
catch (Exception ex)
{
tcs.SetException(ex);
rpc.tcs.SetException(ex);
throw;
}
}
@ -812,24 +810,29 @@ namespace WTelegram
return new RpcResult { req_msg_id = msgId, result = result };
}
private (Type type, TaskCompletionSource<object> tcs) PullPendingRequest(long msgId)
class Rpc
{
(Type type, TaskCompletionSource<object> tcs) request;
lock (_pendingRequests)
if (_pendingRequests.TryGetValue(msgId, out request))
_pendingRequests.Remove(msgId);
public Type type;
public TaskCompletionSource<object> tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
public long msgId;
public Task<object> Task => tcs.Task;
}
private Rpc PullPendingRequest(long msgId)
{
Rpc request;
lock (_pendingRpcs)
if (_pendingRpcs.TryGetValue(msgId, out request))
_pendingRpcs.Remove(msgId);
return request;
}
internal async Task<X> InvokeBare<X>(IMethod<X> request)
{
if (_bareRequest != 0) throw new ApplicationException("A bare request is already undergoing");
var msgId = await SendAsync(request, false);
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
lock (_pendingRequests)
_pendingRequests[msgId] = (typeof(X), tcs);
_bareRequest = msgId;
return (X)await tcs.Task;
if (_bareRpc != null) throw new ApplicationException("A bare request is already undergoing");
_bareRpc = new Rpc { type = typeof(X) };
await SendAsync(request, false, _bareRpc);
return (X)await _bareRpc.Task;
}
/// <summary>Call the given TL method <i>(You shouldn't need to use this method directly)</i></summary>
@ -839,12 +842,10 @@ namespace WTelegram
public async Task<X> Invoke<X>(IMethod<X> query)
{
retry:
var msgId = await SendAsync(query, true);
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
lock (_pendingRequests)
_pendingRequests[msgId] = (typeof(X), tcs);
var rpc = new Rpc { type = typeof(X) };
await SendAsync(query, true, rpc);
bool got503 = false;
var result = await tcs.Task;
var result = await rpc.Task;
switch (result)
{
case X resultX: return resultX;
@ -906,17 +907,6 @@ namespace WTelegram
}
}
private static MsgContainer MakeContainer(params (IObject obj, (long msgId, int seqno))[] msgs)
=> new()
{
messages = msgs.Select(msg => new _Message
{
msg_id = msg.Item2.msgId,
seqno = msg.Item2.seqno,
body = msg.obj
}).ToArray()
};
private async Task HandleMessageAsync(IObject obj)
{
switch (obj)
@ -979,30 +969,29 @@ namespace WTelegram
}
if (retryLast)
{
var newMsgId = await SendAsync(lastSentMsg, true);
lock (_pendingRequests)
if (_pendingRequests.TryGetValue(badMsgNotification.bad_msg_id, out var t))
{
_pendingRequests.Remove(badMsgNotification.bad_msg_id);
_pendingRequests[newMsgId] = t;
}
Rpc prevRequest;
lock (_pendingRpcs)
_pendingRpcs.TryGetValue(badMsgNotification.bad_msg_id, out prevRequest);
await SendAsync(lastSentMsg, true, prevRequest);
lock (_pendingRpcs)
_pendingRpcs.Remove(badMsgNotification.bad_msg_id);
}
else if (PullPendingRequest(badMsgNotification.bad_msg_id).tcs is TaskCompletionSource<object> tcs)
else if (PullPendingRequest(badMsgNotification.bad_msg_id) is Rpc rpc)
{
if (_bareRequest == badMsgNotification.bad_msg_id) _bareRequest = 0;
tcs.SetException(new ApplicationException($"BadMsgNotification {badMsgNotification.error_code}"));
if (_bareRpc.msgId == badMsgNotification.bad_msg_id) _bareRpc = null;
rpc.tcs.SetException(new ApplicationException($"BadMsgNotification {badMsgNotification.error_code}"));
}
else
OnUpdate(obj);
break;
default:
if (_bareRequest != 0)
if (_bareRpc != null)
{
var (type, tcs) = PullPendingRequest(_bareRequest);
if (type?.IsAssignableFrom(obj.GetType()) == true)
var rpc = PullPendingRequest(_bareRpc.msgId);
if (rpc?.type.IsAssignableFrom(obj.GetType()) == true)
{
_bareRequest = 0;
tcs.SetResult(obj);
_bareRpc = null;
rpc.tcs.SetResult(obj);
break;
}
}
@ -1012,9 +1001,9 @@ namespace WTelegram
void SetResult(long msgId, object result)
{
var tcs = PullPendingRequest(msgId).tcs;
if (tcs != null)
tcs.SetResult(result);
var rpc = PullPendingRequest(msgId);
if (rpc != null)
rpc.tcs.SetResult(result);
else
OnUpdate(obj);
}

View file

@ -374,6 +374,7 @@ namespace TL
[TLDef(0x5BB8E511)] //message#5bb8e511 msg_id:long seqno:int bytes:int body:Object = Message
public class _Message
{
public _Message(long msgId, int seqNo, IObject obj) { msg_id = msgId; seqno = seqNo; body = obj; }
public long msg_id;
public int seqno;
public int bytes;