Fix issue with actual RpcResult in MsgContainer ; Parallelize upload of file parts

This commit is contained in:
Wizou 2021-08-13 07:06:44 +02:00
parent 897b61747a
commit e01caba162
6 changed files with 151 additions and 110 deletions

View file

@ -249,6 +249,43 @@ namespace WTelegram
return msgId;
}
/// <summary>
/// Keeps reading from <paramref name="stream"/> into the start of <paramref name="buffer"/>
/// until <paramref name="length"/> bytes have been stored or the stream reports end-of-data.
/// </summary>
/// <param name="stream">Stream to read from.</param>
/// <param name="buffer">Destination buffer, filled from index 0.</param>
/// <param name="length">Number of bytes wanted.</param>
/// <param name="ct">Token forwarded to every underlying read.</param>
/// <returns>
/// The number of bytes actually stored: <paramref name="length"/> on success,
/// or a smaller value if a read returned 0 before the request was satisfied.
/// </returns>
private static async Task<int> FullReadAsync(Stream stream, byte[] buffer, int length, CancellationToken ct = default)
{
    int filled = 0;
    while (filled != length)
    {
        // A single ReadAsync may deliver fewer bytes than requested, so ask only for what is still missing.
        int chunk = await stream.ReadAsync(buffer.AsMemory(filled, length - filled), ct);
        if (chunk == 0)
            break; // the stream ended early: report how far we got
        filled += chunk;
    }
    return filled;
}
/// <summary>
/// Reads one complete frame from <c>_networkStream</c> and returns its payload.
/// </summary>
/// <remarks>
/// As parsed here, a frame consists of an 8-byte prefix (little-endian Int32 frame length,
/// then little-endian Int32 sequence number), the payload (frame length minus 12 bytes),
/// and a trailing little-endian UInt32 CRC32 computed over the prefix followed by the payload.
/// </remarks>
/// <param name="ct">Token forwarded to the underlying stream reads.</param>
/// <returns>The payload bytes of the frame.</returns>
/// <exception cref="ApplicationException">
/// The stream ended before the prefix, payload or CRC was fully read,
/// the payload length is not in the range 1..0xFFFF, or the CRC32 does not match.
/// </exception>
private async Task<byte[]> RecvFrameAsync(CancellationToken ct)
{
    byte[] frame = new byte[8];
    if (await FullReadAsync(_networkStream, frame, 8, ct) != 8)
        throw new ApplicationException("Could not read frame prefix : Connection shut down");
    // The length field counts the 8-byte prefix and the 4-byte CRC read below, hence the -12.
    int length = BinaryPrimitives.ReadInt32LittleEndian(frame) - 12;
    if (length <= 0 || length >= 0x10000)
        throw new ApplicationException("Invalid frame_len");
    int seqno = BinaryPrimitives.ReadInt32LittleEndian(frame.AsSpan(4));
    // Capture the expected value before advancing the counter, so that the warning
    // reports the sequence number we were really waiting for (not that value + 1).
    int expectedSeqno = _frame_seqRx++;
    if (seqno != expectedSeqno)
    {
        Trace.TraceWarning($"Unexpected frame_seq received: {seqno} instead of {expectedSeqno}");
        _frame_seqRx = seqno + 1; // resynchronize on the sequence number actually received
    }
    var payload = new byte[length];
    if (await FullReadAsync(_networkStream, payload, length, ct) != length)
        throw new ApplicationException("Could not read frame data : Connection shut down");
    // The CRC covers the untouched 8-byte prefix, then the payload.
    uint crc32 = Force.Crc32.Crc32Algorithm.Compute(frame, 0, 8);
    crc32 = Force.Crc32.Crc32Algorithm.Append(crc32, payload);
    // The prefix buffer is no longer needed, so its first 4 bytes are reused to receive the CRC.
    if (await FullReadAsync(_networkStream, frame, 4, ct) != 4)
        throw new ApplicationException("Could not read frame CRC : Connection shut down");
    if (crc32 != BinaryPrimitives.ReadUInt32LittleEndian(frame))
        throw new ApplicationException("Invalid envelope CRC32");
    return payload;
}
internal async Task<ITLObject> RecvInternalAsync(CancellationToken ct)
{
var data = await RecvFrameAsync(ct);
@ -297,8 +334,8 @@ namespace WTelegram
{
Helpers.Log(3, $"Server salt has changed: {_session.Salt:X8} -> {serverSalt:X8}");
_session.Salt = serverSalt;
if (++_unexpectedSaltChange >= 10)
throw new ApplicationException($"Server salt changed unexpectedly more than 10 times during this run");
if (++_unexpectedSaltChange >= 30)
throw new ApplicationException($"Server salt changed unexpectedly more than 30 times during this session");
}
// Show the session ID that was received next to the one we expected (the message used to print the expected ID twice).
if (sessionId != _session.Id) throw new ApplicationException($"Unexpected session ID {sessionId} != {_session.Id}");
if ((msgId & 1) == 0) throw new ApplicationException($"Invalid server msgId {msgId}");
@ -314,11 +351,23 @@ namespace WTelegram
if (!data.AsSpan(8, 16).SequenceEqual(_sha256.Hash.AsSpan(8, 16)))
throw new ApplicationException($"Mismatch between MsgKey & decrypted SHA1");
#endif
var obj = reader.ReadTLObject(type => type == typeof(RpcResult));
Helpers.Log(1, $"Receiving {obj.GetType().Name,-50} {_session.MsgIdToStamp(msgId):u} {((seqno & 1) != 0 ? "" : "(svc)")} {((msgId & 2) == 0 ? "" : "NAR")}");
if (obj is RpcResult rpcResult)
DeserializeRpcResult(reader, rpcResult); // necessary hack because some RPC return bare types like bool or int[]
return obj;
var ctorNb = reader.ReadUInt32();
if (ctorNb == Schema.MsgContainer)
{
Helpers.Log(1, $"Receiving {"MsgContainer",-50} {_session.MsgIdToStamp(msgId):u} (svc)");
return ReadMsgContainer(reader);
}
else if (ctorNb == Schema.RpcResult)
{
Helpers.Log(1, $"Receiving {"RpcResult",-50} {_session.MsgIdToStamp(msgId):u}");
return ReadRpcResult(reader);
}
else
{
var obj = reader.ReadTLObject(ctorNb);
Helpers.Log(1, $"Receiving {obj.GetType().Name,-50} {_session.MsgIdToStamp(msgId):u} {((seqno & 1) != 0 ? "" : "(svc)")} {((msgId & 2) == 0 ? "" : "NAR")}");
return obj;
}
}
static string TransportError(int error_code) => error_code switch
@ -329,65 +378,66 @@ namespace WTelegram
};
}
private async Task<byte[]> RecvFrameAsync(CancellationToken ct)
internal MsgContainer ReadMsgContainer(BinaryReader reader)
{
byte[] frame = new byte[8];
if (await FullReadAsync(_networkStream, frame, 8, ct) != 8)
throw new ApplicationException("Could not read frame prefix : Connection shut down");
int length = BinaryPrimitives.ReadInt32LittleEndian(frame) - 12;
if (length <= 0 || length >= 0x10000)
throw new ApplicationException("Invalid frame_len");
int seqno = BinaryPrimitives.ReadInt32LittleEndian(frame.AsSpan(4));
if (seqno != _frame_seqRx++)
int count = reader.ReadInt32();
var array = new _Message[count];
for (int i = 0; i < count; i++)
{
Trace.TraceWarning($"Unexpected frame_seq received: {seqno} instead of {_frame_seqRx}");
_frame_seqRx = seqno + 1;
var msg = array[i] = new _Message
{
msg_id = reader.ReadInt64(),
seqno = reader.ReadInt32(),
bytes = reader.ReadInt32(),
};
if ((msg.seqno & 1) != 0) lock(_msgsToAck) _msgsToAck.Add(msg.msg_id);
var pos = reader.BaseStream.Position;
try
{
var ctorNb = reader.ReadUInt32();
if (ctorNb == Schema.RpcResult)
{
Helpers.Log(1, $" → {"RpcResult",-48} {_session.MsgIdToStamp(msg.msg_id):u}");
msg.body = ReadRpcResult(reader);
}
else
{
var obj = msg.body = reader.ReadTLObject(ctorNb);
Helpers.Log(1, $" → {obj.GetType().Name,-48} {_session.MsgIdToStamp(msg.msg_id):u} {((msg.seqno & 1) != 0 ? "" : "(svc)")} {((msg.msg_id & 2) == 0 ? "" : "NAR")}");
}
}
catch (Exception ex)
{
Helpers.Log(4, "While deserializing vector<%Message>: " + ex.ToString());
}
reader.BaseStream.Position = pos + array[i].bytes;
}
var payload = new byte[length];
if (await FullReadAsync(_networkStream, payload, length, ct) != length)
throw new ApplicationException("Could not read frame data : Connection shut down");
uint crc32 = Force.Crc32.Crc32Algorithm.Compute(frame, 0, 8);
crc32 = Force.Crc32.Crc32Algorithm.Append(crc32, payload);
if (await FullReadAsync(_networkStream, frame, 4, ct) != 4)
throw new ApplicationException("Could not read frame CRC : Connection shut down");
if (crc32 != BinaryPrimitives.ReadUInt32LittleEndian(frame))
throw new ApplicationException("Invalid envelope CRC32");
return payload;
return new MsgContainer { messages = array };
}
private static async Task<int> FullReadAsync(Stream stream, byte[] buffer, int length, CancellationToken ct = default)
private RpcResult ReadRpcResult(BinaryReader reader)
{
for (int offset = 0; offset != length;)
{
var read = await stream.ReadAsync(buffer.AsMemory(offset, length - offset), ct);
if (read == 0) return offset;
offset += read;
}
return length;
}
private bool DeserializeRpcResult(BinaryReader reader, RpcResult rpcResult)
{
var msgId = rpcResult.req_msg_id = reader.ReadInt64();
long msgId = reader.ReadInt64();
(Type type, TaskCompletionSource<object> tcs) request;
lock (_pendingRequests)
if (_pendingRequests.TryGetValue(msgId, out request))
_pendingRequests.Remove(msgId);
if (request.type != null)
{
rpcResult.result = reader.ReadTLValue(request.type);
Helpers.Log(1, $" result → {request.type.Name,-48} #{(short)msgId.GetHashCode():X4}");
Task.Run(() => request.tcs.SetResult(rpcResult.result)); // to avoid deadlock, see https://blog.stephencleary.com/2012/12/dont-block-in-asynchronous-code.html
var result = reader.ReadTLValue(request.type);
Helpers.Log(1, $" → {result?.GetType().Name,-47} #{(short)msgId.GetHashCode():X4}");
Task.Run(() => request.tcs.SetResult(result)); // to avoid deadlock, see https://blog.stephencleary.com/2012/12/dont-block-in-asynchronous-code.html
return new RpcResult { req_msg_id = msgId, result = result };
}
else
{
rpcResult.result = reader.ReadTLObject();
var result = reader.ReadTLObject();
if (_session.MsgIdToStamp(msgId) >= _session.SessionStart)
Helpers.Log(4, $" result → {rpcResult.result?.GetType().Name,-48}) for unknown msgId #{(short)msgId.GetHashCode():X4}");
Helpers.Log(4, $" → {result?.GetType().Name,-47} for unknown msgId #{(short)msgId.GetHashCode():X4}");
else
Helpers.Log(1, $" result → {rpcResult.result?.GetType().Name,-48} for past msgId #{(short)msgId.GetHashCode():X4}");
Helpers.Log(1, $" → {result?.GetType().Name,-47} for past msgId #{(short)msgId.GetHashCode():X4}");
return new RpcResult { req_msg_id = msgId, result = result };
}
return true;
}
public class RpcException : Exception
@ -481,13 +531,8 @@ namespace WTelegram
{
case MsgContainer container:
foreach (var msg in container.messages)
{
var typeName = msg.body?.GetType().Name;
if (typeName == "RpcResult") typeName += $" ({((RpcResult)msg.body).result.GetType().Name})";
Helpers.Log(1, $" → {typeName,-48} {_session.MsgIdToStamp(msg.msg_id):u} {((msg.seqno & 1) != 0 ? "" : "(svc)")} {((msg.msg_id & 2) == 0 ? "" : "NAR")}");
if ((msg.seqno & 1) != 0) lock (_msgsToAck) _msgsToAck.Add(msg.msg_id);
if (msg.body != null) await HandleMessageAsync(msg.body);
}
if (msg.body != null)
await HandleMessageAsync(msg.body);
break;
case BadServerSalt badServerSalt:
_session.Salt = badServerSalt.new_server_salt;
@ -506,8 +551,7 @@ namespace WTelegram
Helpers.Log(3, $"BadMsgNotification {badMsgNotification.error_code} for msg #{(short)badMsgNotification.bad_msg_id.GetHashCode():X4}");
break;
case RpcResult rpcResult:
//if (_session.MsgIdToStamp(rpcResult.req_msg_id) >= _session.SessionStart)
break; // tcs wake up was already done in DeserializeRpcResult
break; // wake-up of waiting task was already done in ReadRpcResult
default:
if (_rawRequest != null)
{
@ -575,23 +619,47 @@ namespace WTelegram
const int partSize = 512 * 1024;
int file_total_parts = (int)((length - 1) / partSize) + 1;
long file_id = Helpers.RandomLong();
var bytes = new byte[Math.Min(partSize, length)];
int file_part = 0, read;
for (long bytesLeft = length; bytesLeft != 0; file_part++)
const int ParallelSends = 10;
var semaphore = new SemaphoreSlim(ParallelSends);
var tasks = new Dictionary<int, Task>();
bool abort = false;
for (long bytesLeft = length; !abort && bytesLeft != 0; file_part++)
{
//TODO: parallelize several parts sending through a N-semaphore? (needs a reactor first)
read = await FullReadAsync(stream, bytes, (int)Math.Min(partSize, bytesLeft));
if (isBig)
await Upload_SaveBigFilePart(file_id, file_part, file_total_parts, bytes);
else
{
await Upload_SaveFilePart(file_id, file_part, bytes);
var bytes = new byte[Math.Min(partSize, bytesLeft)];
read = await FullReadAsync(stream, bytes, bytes.Length);
await semaphore.WaitAsync();
var task = SavePart(file_part, bytes);
lock (tasks) tasks[file_part] = task;
if (!isBig)
md5.TransformBlock(bytes, 0, read, null, 0);
}
bytesLeft -= read;
if (read < partSize && bytesLeft != 0) throw new ApplicationException($"Failed to fully read stream ({read},{bytesLeft})");
async Task SavePart(int file_part, byte[] bytes)
{
try
{
if (isBig)
await Upload_SaveBigFilePart(file_id, file_part, file_total_parts, bytes);
else
await Upload_SaveFilePart(file_id, file_part, bytes);
lock (tasks) tasks.Remove(file_part);
}
catch (Exception)
{
abort = true;
}
finally
{
semaphore.Release();
}
}
}
if (!isBig) md5.TransformFinalBlock(bytes, 0, 0);
for (int i = 0; i < ParallelSends; i++)
await semaphore.WaitAsync(); // wait for all the remaining parts to be sent
await Task.WhenAll(tasks.Values); // propagate any task exception (tasks should be empty on success)
if (!isBig) md5.TransformFinalBlock(Array.Empty<byte>(), 0, 0);
return isBig ? new InputFileBig { id = file_id, parts = file_total_parts, name = filename }
: new InputFile { id = file_id, parts = file_total_parts, name = filename, md5_checksum = md5.Hash };
}

View file

@ -7,8 +7,10 @@ namespace TL
static partial class Schema
{
public const int Layer = 121; // fetched 10/08/2021 11:46:24
public const int VectorCtor = 0x1CB5C415;
public const int NullCtor = 0x56730BCC;
public const uint VectorCtor = 0x1CB5C415;
public const uint NullCtor = 0x56730BCC;
public const uint RpcResult = 0xF35C6D01;
public const uint MsgContainer = 0x73F1F8DC;
internal readonly static Dictionary<uint, Type> Table = new()
{

View file

@ -51,14 +51,13 @@ namespace TL
}
}
internal static ITLObject ReadTLObject(this BinaryReader reader, Func<Type, bool> notifyType = null)
internal static ITLObject ReadTLObject(this BinaryReader reader, uint ctorNb = 0)
{
var ctorNb = reader.ReadUInt32();
if (ctorNb == 0) ctorNb = reader.ReadUInt32();
if (ctorNb == NullCtor) return null;
if (!Table.TryGetValue(ctorNb, out var type))
throw new ApplicationException($"Cannot find type for ctor #{ctorNb:x}");
var obj = Activator.CreateInstance(type);
if (notifyType?.Invoke(type) == true) return (ITLObject) obj;
var fields = obj.GetType().GetFields().GroupBy(f => f.DeclaringType).Reverse().SelectMany(g => g);
int flags = 0;
IfFlagAttribute ifFlag;
@ -129,8 +128,6 @@ namespace TL
{
if (type == typeof(byte[]))
return reader.ReadTLBytes();
else if (type == typeof(_Message[]))
return reader.ReadTLMessages();
else
return reader.ReadTLVector(type);
}
@ -139,7 +136,7 @@ namespace TL
else if (type == typeof(Int256))
return new Int256(reader);
else
return ReadTLObject(reader);
return reader.ReadTLObject();
default:
ShouldntBeHere();
return null;
@ -158,7 +155,7 @@ namespace TL
internal static Array ReadTLVector(this BinaryReader reader, Type type)
{
var ctorNb = reader.ReadInt32();
var ctorNb = reader.ReadUInt32();
if (ctorNb != VectorCtor) throw new ApplicationException($"Cannot deserialize {type.Name} with ctor #{ctorNb:x}");
var elementType = type.GetElementType();
int count = reader.ReadInt32();
@ -225,36 +222,10 @@ namespace TL
writer.Write(0); // null arrays are serialized as empty
}
/// <summary>
/// Reads an Int32 element count followed by that many messages, each made of a header
/// (Int64 msg_id, Int32 seqno, Int32 bytes) and a body deserialized with ReadTLObject.
/// </summary>
/// <remarks>
/// A body that fails to deserialize is logged and left unset; in every case the stream is
/// repositioned to the start of the body plus the declared <c>bytes</c>, so the next
/// message header is read from the expected offset.
/// </remarks>
/// <param name="reader">Reader positioned on the element count.</param>
/// <returns>The array of messages read.</returns>
internal static _Message[] ReadTLMessages(this BinaryReader reader)
{
    var messages = new _Message[reader.ReadInt32()];
    for (int index = 0; index < messages.Length; index++)
    {
        var message = new _Message
        {
            msg_id = reader.ReadInt64(),
            seqno = reader.ReadInt32(),
            bytes = reader.ReadInt32(),
        };
        // Remember where this body is supposed to end before attempting to parse it.
        long bodyEnd = reader.BaseStream.Position + message.bytes;
        try
        {
            message.body = reader.ReadTLObject();
        }
        catch (Exception ex)
        {
            // Best effort: keep the header, report the failure, and carry on with the next message.
            Helpers.Log(4, "While deserializing vector<%Message>: " + ex.ToString());
        }
        messages[index] = message;
        reader.BaseStream.Position = bodyEnd;
    }
    return messages;
}
internal static ITLObject UnzipPacket(GzipPacked obj)
{
using var reader = new BinaryReader(new GZipStream(new MemoryStream(obj.packed_data), CompressionMode.Decompress));
var result = ReadTLObject(reader);
var result = reader.ReadTLObject();
return result;
}