mirror of
https://github.com/wiz0u/WTelegramClient.git
synced 2026-01-29 11:44:19 +01:00
Fix race condition on pendingRpcs adding/pulling
This commit is contained in:
parent
a54cc92618
commit
5c5b8032b9
155
src/Client.cs
155
src/Client.cs
|
|
@ -65,8 +65,8 @@ namespace WTelegram
|
|||
private readonly Random _random = new();
|
||||
private int _saltChangeCounter;
|
||||
private Task _reactorTask;
|
||||
private long _bareRequest;
|
||||
private readonly Dictionary<long, (Type type, TaskCompletionSource<object> tcs)> _pendingRequests = new();
|
||||
private Rpc _bareRpc;
|
||||
private readonly Dictionary<long, Rpc> _pendingRpcs = new();
|
||||
private SemaphoreSlim _sendSemaphore = new(0);
|
||||
private readonly SemaphoreSlim _semaphore = new(1);
|
||||
private Task _connecting;
|
||||
|
|
@ -501,20 +501,20 @@ namespace WTelegram
|
|||
lock (_msgsToAck) _msgsToAck.Clear();
|
||||
Reset(false, false);
|
||||
_reactorReconnects = (_reactorReconnects + 1) % MaxAutoReconnects;
|
||||
if (!IsMainDC && _pendingRequests.Count <= 1 && ex is ApplicationException { Message: ConnectionShutDown } or IOException { InnerException: SocketException })
|
||||
if (_pendingRequests.Values.FirstOrDefault() is var (type, tcs) && (type is null || type == typeof(Pong)))
|
||||
if (!IsMainDC && _pendingRpcs.Count <= 1 && ex is ApplicationException { Message: ConnectionShutDown } or IOException { InnerException: SocketException })
|
||||
if (_pendingRpcs.Values.FirstOrDefault() is not Rpc rpc || rpc.type == typeof(Pong))
|
||||
_reactorReconnects = 0;
|
||||
if (_reactorReconnects != 0)
|
||||
{
|
||||
await Task.Delay(5000);
|
||||
if (_networkStream == null) return; // Dispose has been called in-between
|
||||
await ConnectAsync(); // start a new reactor after 5 secs
|
||||
lock (_pendingRequests) // retry all pending requests
|
||||
lock (_pendingRpcs) // retry all pending requests
|
||||
{
|
||||
foreach (var (_, tcs) in _pendingRequests.Values)
|
||||
tcs.SetResult(reactorError);
|
||||
_pendingRequests.Clear();
|
||||
_bareRequest = 0;
|
||||
foreach (var rpc in _pendingRpcs.Values)
|
||||
rpc.tcs.SetResult(reactorError);
|
||||
_pendingRpcs.Clear();
|
||||
_bareRpc = null;
|
||||
}
|
||||
// TODO: implement an Updates gaps handling system? https://core.telegram.org/api/updates
|
||||
if (IsMainDC)
|
||||
|
|
@ -528,12 +528,12 @@ namespace WTelegram
|
|||
}
|
||||
catch
|
||||
{
|
||||
lock (_pendingRequests) // abort all pending requests
|
||||
lock (_pendingRpcs) // abort all pending requests
|
||||
{
|
||||
foreach (var (_, tcs) in _pendingRequests.Values)
|
||||
tcs.SetException(ex);
|
||||
_pendingRequests.Clear();
|
||||
_bareRequest = 0;
|
||||
foreach (var rpc in _pendingRpcs.Values)
|
||||
rpc.tcs.SetException(ex);
|
||||
_pendingRpcs.Clear();
|
||||
_bareRpc = null;
|
||||
}
|
||||
OnUpdate(reactorError);
|
||||
}
|
||||
|
|
@ -639,16 +639,20 @@ namespace WTelegram
|
|||
};
|
||||
}
|
||||
|
||||
private async Task<long> SendAsync(IObject msg, bool isContent)
|
||||
private async Task SendAsync(IObject msg, bool isContent, Rpc rpc = null)
|
||||
{
|
||||
if (_dcSession.AuthKeyID != 0 && isContent && CheckMsgsToAck() is MsgsAck msgsAck)
|
||||
isContent &= _dcSession.AuthKeyID != 0;
|
||||
(long msgId, int seqno) = NewMsgId(isContent);
|
||||
if (rpc != null)
|
||||
lock (_pendingRpcs)
|
||||
_pendingRpcs[rpc.msgId = msgId] = rpc;
|
||||
if (isContent && CheckMsgsToAck() is MsgsAck msgsAck)
|
||||
{
|
||||
var ackMsg = NewMsgId(false);
|
||||
var mainMsg = NewMsgId(true);
|
||||
await SendAsync(MakeContainer((msgsAck, ackMsg), (msg, mainMsg)), false);
|
||||
return mainMsg.msgId;
|
||||
var (ackId, ackSeqno) = NewMsgId(false);
|
||||
var container = new MsgContainer { messages = new _Message[] { new(msgId, seqno, msg), new(ackId, ackSeqno, msgsAck) } };
|
||||
await SendAsync(container, false);
|
||||
return;
|
||||
}
|
||||
(long msgId, int seqno) = NewMsgId(isContent && _dcSession.AuthKeyID != 0);
|
||||
await _sendSemaphore.WaitAsync();
|
||||
try
|
||||
{
|
||||
|
|
@ -714,7 +718,6 @@ namespace WTelegram
|
|||
{
|
||||
_sendSemaphore.Release();
|
||||
}
|
||||
return msgId;
|
||||
}
|
||||
|
||||
internal MsgContainer ReadMsgContainer(TL.BinaryReader reader)
|
||||
|
|
@ -723,12 +726,7 @@ namespace WTelegram
|
|||
var array = new _Message[count];
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
var msg = array[i] = new _Message
|
||||
{
|
||||
msg_id = reader.ReadInt64(),
|
||||
seqno = reader.ReadInt32(),
|
||||
bytes = reader.ReadInt32(),
|
||||
};
|
||||
var msg = array[i] = new _Message(reader.ReadInt64(), reader.ReadInt32(), null) { bytes = reader.ReadInt32() };
|
||||
if ((msg.seqno & 1) != 0) lock (_msgsToAck) _msgsToAck.Add(msg.msg_id);
|
||||
var pos = reader.BaseStream.Position;
|
||||
try
|
||||
|
|
@ -757,14 +755,14 @@ namespace WTelegram
|
|||
private RpcResult ReadRpcResult(TL.BinaryReader reader)
|
||||
{
|
||||
long msgId = reader.ReadInt64();
|
||||
var (type, tcs) = PullPendingRequest(msgId);
|
||||
var rpc = PullPendingRequest(msgId);
|
||||
object result;
|
||||
if (tcs != null)
|
||||
if (rpc != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (!type.IsArray)
|
||||
result = reader.ReadTLValue(type);
|
||||
if (!rpc.type.IsArray)
|
||||
result = reader.ReadTLValue(rpc.type);
|
||||
else
|
||||
{
|
||||
var peek = reader.ReadUInt32();
|
||||
|
|
@ -772,23 +770,23 @@ namespace WTelegram
|
|||
result = reader.ReadTLObject(Layer.RpcErrorCtor);
|
||||
else if (peek == Layer.GZipedCtor)
|
||||
using (var gzipReader = new TL.BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress), reader.Client))
|
||||
result = gzipReader.ReadTLValue(type);
|
||||
result = gzipReader.ReadTLValue(rpc.type);
|
||||
else
|
||||
{
|
||||
reader.BaseStream.Position -= 4;
|
||||
result = reader.ReadTLValue(type);
|
||||
result = reader.ReadTLValue(rpc.type);
|
||||
}
|
||||
}
|
||||
if (type.IsEnum) result = Enum.ToObject(type, result);
|
||||
if (rpc.type.IsEnum) result = Enum.ToObject(rpc.type, result);
|
||||
if (result is RpcError rpcError)
|
||||
Helpers.Log(4, $" → RpcError {rpcError.error_code,3} {rpcError.error_message,-24} #{(short)msgId.GetHashCode():X4}");
|
||||
else
|
||||
Helpers.Log(1, $" → {result?.GetType().Name,-37} #{(short)msgId.GetHashCode():X4}");
|
||||
tcs.SetResult(result);
|
||||
rpc.tcs.SetResult(result);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
tcs.SetException(ex);
|
||||
rpc.tcs.SetException(ex);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
|
@ -812,24 +810,29 @@ namespace WTelegram
|
|||
return new RpcResult { req_msg_id = msgId, result = result };
|
||||
}
|
||||
|
||||
private (Type type, TaskCompletionSource<object> tcs) PullPendingRequest(long msgId)
|
||||
class Rpc
|
||||
{
|
||||
(Type type, TaskCompletionSource<object> tcs) request;
|
||||
lock (_pendingRequests)
|
||||
if (_pendingRequests.TryGetValue(msgId, out request))
|
||||
_pendingRequests.Remove(msgId);
|
||||
public Type type;
|
||||
public TaskCompletionSource<object> tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
public long msgId;
|
||||
public Task<object> Task => tcs.Task;
|
||||
}
|
||||
|
||||
private Rpc PullPendingRequest(long msgId)
|
||||
{
|
||||
Rpc request;
|
||||
lock (_pendingRpcs)
|
||||
if (_pendingRpcs.TryGetValue(msgId, out request))
|
||||
_pendingRpcs.Remove(msgId);
|
||||
return request;
|
||||
}
|
||||
|
||||
internal async Task<X> InvokeBare<X>(IMethod<X> request)
|
||||
{
|
||||
if (_bareRequest != 0) throw new ApplicationException("A bare request is already undergoing");
|
||||
var msgId = await SendAsync(request, false);
|
||||
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
lock (_pendingRequests)
|
||||
_pendingRequests[msgId] = (typeof(X), tcs);
|
||||
_bareRequest = msgId;
|
||||
return (X)await tcs.Task;
|
||||
if (_bareRpc != null) throw new ApplicationException("A bare request is already undergoing");
|
||||
_bareRpc = new Rpc { type = typeof(X) };
|
||||
await SendAsync(request, false, _bareRpc);
|
||||
return (X)await _bareRpc.Task;
|
||||
}
|
||||
|
||||
/// <summary>Call the given TL method <i>(You shouldn't need to use this method directly)</i></summary>
|
||||
|
|
@ -839,12 +842,10 @@ namespace WTelegram
|
|||
public async Task<X> Invoke<X>(IMethod<X> query)
|
||||
{
|
||||
retry:
|
||||
var msgId = await SendAsync(query, true);
|
||||
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
lock (_pendingRequests)
|
||||
_pendingRequests[msgId] = (typeof(X), tcs);
|
||||
var rpc = new Rpc { type = typeof(X) };
|
||||
await SendAsync(query, true, rpc);
|
||||
bool got503 = false;
|
||||
var result = await tcs.Task;
|
||||
var result = await rpc.Task;
|
||||
switch (result)
|
||||
{
|
||||
case X resultX: return resultX;
|
||||
|
|
@ -906,17 +907,6 @@ namespace WTelegram
|
|||
}
|
||||
}
|
||||
|
||||
private static MsgContainer MakeContainer(params (IObject obj, (long msgId, int seqno))[] msgs)
|
||||
=> new()
|
||||
{
|
||||
messages = msgs.Select(msg => new _Message
|
||||
{
|
||||
msg_id = msg.Item2.msgId,
|
||||
seqno = msg.Item2.seqno,
|
||||
body = msg.obj
|
||||
}).ToArray()
|
||||
};
|
||||
|
||||
private async Task HandleMessageAsync(IObject obj)
|
||||
{
|
||||
switch (obj)
|
||||
|
|
@ -979,30 +969,29 @@ namespace WTelegram
|
|||
}
|
||||
if (retryLast)
|
||||
{
|
||||
var newMsgId = await SendAsync(lastSentMsg, true);
|
||||
lock (_pendingRequests)
|
||||
if (_pendingRequests.TryGetValue(badMsgNotification.bad_msg_id, out var t))
|
||||
{
|
||||
_pendingRequests.Remove(badMsgNotification.bad_msg_id);
|
||||
_pendingRequests[newMsgId] = t;
|
||||
}
|
||||
Rpc prevRequest;
|
||||
lock (_pendingRpcs)
|
||||
_pendingRpcs.TryGetValue(badMsgNotification.bad_msg_id, out prevRequest);
|
||||
await SendAsync(lastSentMsg, true, prevRequest);
|
||||
lock (_pendingRpcs)
|
||||
_pendingRpcs.Remove(badMsgNotification.bad_msg_id);
|
||||
}
|
||||
else if (PullPendingRequest(badMsgNotification.bad_msg_id).tcs is TaskCompletionSource<object> tcs)
|
||||
else if (PullPendingRequest(badMsgNotification.bad_msg_id) is Rpc rpc)
|
||||
{
|
||||
if (_bareRequest == badMsgNotification.bad_msg_id) _bareRequest = 0;
|
||||
tcs.SetException(new ApplicationException($"BadMsgNotification {badMsgNotification.error_code}"));
|
||||
if (_bareRpc.msgId == badMsgNotification.bad_msg_id) _bareRpc = null;
|
||||
rpc.tcs.SetException(new ApplicationException($"BadMsgNotification {badMsgNotification.error_code}"));
|
||||
}
|
||||
else
|
||||
OnUpdate(obj);
|
||||
break;
|
||||
default:
|
||||
if (_bareRequest != 0)
|
||||
if (_bareRpc != null)
|
||||
{
|
||||
var (type, tcs) = PullPendingRequest(_bareRequest);
|
||||
if (type?.IsAssignableFrom(obj.GetType()) == true)
|
||||
var rpc = PullPendingRequest(_bareRpc.msgId);
|
||||
if (rpc?.type.IsAssignableFrom(obj.GetType()) == true)
|
||||
{
|
||||
_bareRequest = 0;
|
||||
tcs.SetResult(obj);
|
||||
_bareRpc = null;
|
||||
rpc.tcs.SetResult(obj);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -1012,9 +1001,9 @@ namespace WTelegram
|
|||
|
||||
void SetResult(long msgId, object result)
|
||||
{
|
||||
var tcs = PullPendingRequest(msgId).tcs;
|
||||
if (tcs != null)
|
||||
tcs.SetResult(result);
|
||||
var rpc = PullPendingRequest(msgId);
|
||||
if (rpc != null)
|
||||
rpc.tcs.SetResult(result);
|
||||
else
|
||||
OnUpdate(obj);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ namespace TL
|
|||
[TLDef(0x5BB8E511)] //message#5bb8e511 msg_id:long seqno:int bytes:int body:Object = Message
|
||||
public class _Message
|
||||
{
|
||||
public _Message(long msgId, int seqNo, IObject obj) { msg_id = msgId; seqno = seqNo; body = obj; }
|
||||
public long msg_id;
|
||||
public int seqno;
|
||||
public int bytes;
|
||||
|
|
|
|||
Loading…
Reference in a new issue