diff --git a/src/Client.cs b/src/Client.cs index f74fe93..25a65cd 100644 --- a/src/Client.cs +++ b/src/Client.cs @@ -49,6 +49,8 @@ namespace WTelegram private Task _connecting; private CancellationTokenSource _cts; private int _reactorReconnects = 0; + private const int FilePartSize = 512 * 1024; + private readonly SemaphoreSlim _parallelTransfers = new(10); // max parallel part uploads/downloads #if MTPROTO1 private readonly SHA1 _sha1 = SHA1.Create(); private readonly SHA1 _sha1Recv = SHA1.Create(); @@ -161,50 +163,59 @@ namespace WTelegram { var endpoint = _dcSession?.EndPoint ?? Compat.IPEndPoint_Parse(Config("server_address")); Helpers.Log(2, $"Connecting to {endpoint}..."); - _tcpClient = new TcpClient(AddressFamily.InterNetworkV6) { Client = { DualMode = true } }; // this allows both IPv4 & IPv6 + var tcpClient = new TcpClient(AddressFamily.InterNetworkV6) { Client = { DualMode = true } }; // this allows both IPv4 & IPv6 try { - await _tcpClient.ConnectAsync(endpoint.Address, endpoint.Port); - } - catch (SocketException ex) // cannot connect to target endpoint, try to find an alternate - { - Helpers.Log(4, $"SocketException {ex.SocketErrorCode} ({ex.ErrorCode}): {ex.Message}"); - if (_dcSession?.DataCenter == null) throw; - var triedEndpoints = new HashSet { endpoint }; - if (_session.DcOptions != null) + try { - var altOptions = _session.DcOptions.Where(dco => dco.id == _dcSession.DataCenter.id && dco.flags != _dcSession.DataCenter.flags - && (dco.flags & (DcOption.Flags.cdn | DcOption.Flags.tcpo_only | DcOption.Flags.media_only)) == 0) - .OrderBy(dco => dco.flags); - // try alternate addresses for this DC - foreach (var dcOption in altOptions) + await tcpClient.ConnectAsync(endpoint.Address, endpoint.Port); + } + catch (SocketException ex) // cannot connect to target endpoint, try to find an alternate + { + Helpers.Log(4, $"SocketException {ex.SocketErrorCode} ({ex.ErrorCode}): {ex.Message}"); + if (_dcSession?.DataCenter == null) throw; + var triedEndpoints = new HashSet { endpoint }; + if (_session.DcOptions != null) { - endpoint = new(IPAddress.Parse(dcOption.ip_address), dcOption.port); - if (!triedEndpoints.Add(endpoint)) continue; - Helpers.Log(2, $"Connecting to {endpoint}..."); - try + var altOptions = _session.DcOptions.Where(dco => dco.id == _dcSession.DataCenter.id && dco.flags != _dcSession.DataCenter.flags + && (dco.flags & (DcOption.Flags.cdn | DcOption.Flags.tcpo_only | DcOption.Flags.media_only)) == 0) + .OrderBy(dco => dco.flags); + // try alternate addresses for this DC + foreach (var dcOption in altOptions) { - await _tcpClient.ConnectAsync(endpoint.Address, endpoint.Port); - _dcSession.DataCenter = dcOption; - break; + endpoint = new(IPAddress.Parse(dcOption.ip_address), dcOption.port); + if (!triedEndpoints.Add(endpoint)) continue; + Helpers.Log(2, $"Connecting to {endpoint}..."); + try + { + await tcpClient.ConnectAsync(endpoint.Address, endpoint.Port); + _dcSession.DataCenter = dcOption; + break; + } + catch (SocketException) { } } - catch (SocketException) { } + } + if (!tcpClient.Connected) + { + endpoint = Compat.IPEndPoint_Parse(Config("server_address")); // re-ask callback for an address + if (!triedEndpoints.Add(endpoint)) throw; + _dcSession.Client = null; + // is it address for a known DCSession? + _dcSession = _session.DCSessions.Values.FirstOrDefault(dcs => dcs.EndPoint.Equals(endpoint)); + _dcSession ??= new() { Id = Helpers.RandomLong() }; + _dcSession.Client = this; + Helpers.Log(2, $"Connecting to {endpoint}..."); + await tcpClient.ConnectAsync(endpoint.Address, endpoint.Port); } } - if (!_tcpClient.Connected) - { - endpoint = Compat.IPEndPoint_Parse(Config("server_address")); // re-ask callback for an address - if (!triedEndpoints.Add(endpoint)) throw; - _dcSession.Client = null; - // is it address for a known DCSession? - _dcSession = _session.DCSessions.Values.FirstOrDefault(dcs => dcs.EndPoint.Equals(endpoint)); - _dcSession ??= new() { Id = Helpers.RandomLong() }; - _dcSession.Client = this; - Helpers.Log(2, $"Connecting to {endpoint}..."); - await _tcpClient.ConnectAsync(endpoint.Address, endpoint.Port); - } } - _networkStream = _tcpClient.GetStream(); + catch (Exception) + { + tcpClient.Dispose(); + throw; + } + _tcpClient = tcpClient; + _networkStream = tcpClient.GetStream(); await _networkStream.WriteAsync(IntermediateHeader, 0, 4); _cts = new(); _saltChangeCounter = 0; @@ -977,7 +988,6 @@ namespace WTelegram } if (authorization is not Auth_Authorization { user: User user }) throw new ApplicationException("Failed to get Authorization: " + authorization.GetType().Name); - //TODO: find better serialization for User not subject to TL changes? _session.User = user; _dcSession.UserId = user.id; _session.Save(); @@ -997,25 +1007,22 @@ namespace WTelegram { long length = stream.Length; var isBig = length >= 10 * 1024 * 1024; - const int partSize = 512 * 1024; - int file_total_parts = (int)((length - 1) / partSize) + 1; + int file_total_parts = (int)((length - 1) / FilePartSize) + 1; long file_id = Helpers.RandomLong(); int file_part = 0, read; - const int ParallelSends = 10; - var semaphore = new SemaphoreSlim(ParallelSends); var tasks = new Dictionary(); bool abort = false; for (long bytesLeft = length; !abort && bytesLeft != 0; file_part++) { - var bytes = new byte[Math.Min(partSize, bytesLeft)]; + var bytes = new byte[Math.Min(FilePartSize, bytesLeft)]; read = await FullReadAsync(stream, bytes, bytes.Length); - await semaphore.WaitAsync(); + await _parallelTransfers.WaitAsync(); var task = SavePart(file_part, bytes); lock (tasks) tasks[file_part] = task; if (!isBig) md5.TransformBlock(bytes, 0, read, null, 0); bytesLeft -= read; - if (read < partSize && bytesLeft != 0) throw new ApplicationException($"Failed to fully read stream ({read},{bytesLeft})"); + if (read < FilePartSize && bytesLeft != 0) throw new ApplicationException($"Failed to fully read stream ({read},{bytesLeft})"); async Task SavePart(int file_part, byte[] bytes) { @@ -1030,16 +1037,17 @@ namespace WTelegram catch (Exception) { abort = true; + throw; } finally { - semaphore.Release(); + _parallelTransfers.Release(); } } } - for (int i = 0; i < ParallelSends; i++) - await semaphore.WaitAsync(); // wait for all the remaining parts to be sent - await Task.WhenAll(tasks.Values); // propagate any task exception (tasks should be empty on success) + Task[] remainingTasks; + lock (tasks) remainingTasks = tasks.Values.ToArray(); + await Task.WhenAll(remainingTasks); // wait completion and eventually propagate any task exception if (!isBig) md5.TransformFinalBlock(Array.Empty(), 0, 0); return isBig ? new InputFileBig { id = file_id, parts = file_total_parts, name = filename } : new InputFile { id = file_id, parts = file_total_parts, name = filename, md5_checksum = md5.Hash }; @@ -1093,8 +1101,9 @@ namespace WTelegram /// if unspecified, will download the largest version of the photo public async Task DownloadFileAsync(Photo photo, Stream outputStream, PhotoSizeBase photoSize = null) { - var fileLocation = photo.ToFileLocation(photoSize ?? photo.LargestPhotoSize); - return await DownloadFileAsync(fileLocation, outputStream, photo.dc_id); + photoSize ??= photo.LargestPhotoSize; + var fileLocation = photo.ToFileLocation(photoSize); + return await DownloadFileAsync(fileLocation, outputStream, photo.dc_id, photoSize.FileSize); } /// Download given document from Telegram into the outputStream @@ -1104,7 +1113,7 @@ namespace WTelegram public async Task DownloadFileAsync(Document document, Stream outputStream, PhotoSizeBase thumbSize = null) { var fileLocation = document.ToFileLocation(thumbSize); - var fileType = await DownloadFileAsync(fileLocation, outputStream, document.dc_id); + var fileType = await DownloadFileAsync(fileLocation, outputStream, document.dc_id, thumbSize?.FileSize ?? document.size); return thumbSize == null ? document.mime_type : "image/" + fileType; } @@ -1112,34 +1121,95 @@ namespace WTelegram /// Telegram file identifier, typically obtained with a .ToFileLocation() call /// stream to write to. This method does not close/dispose the stream /// (optional) DC on which the file is stored - public async Task DownloadFileAsync(InputFileLocationBase fileLocation, Stream outputStream, int fileDC = 0) + /// (optional) expected file size + public async Task DownloadFileAsync(InputFileLocationBase fileLocation, Stream outputStream, int fileDC = 0, int fileSize = 0) { - const int ChunkSize = 128 * 1024; - int fileSize = 0; - Upload_File fileData; + Storage_FileType fileType = Storage_FileType.unknown; var client = fileDC == 0 ? this : await GetClientForDC(fileDC, true); - do + using var writeSem = new SemaphoreSlim(1); + long streamStartPos = outputStream.Position; + int fileOffset = 0, maxOffsetSeen = 0; + var tasks = new Dictionary(); + bool abort = false; + while (!abort) { - Upload_FileBase fileBase; - try + await _parallelTransfers.WaitAsync(); + var task = LoadPart(fileOffset); + lock (tasks) tasks[fileOffset] = task; + if (fileDC == 0) { await task; fileDC = client._dcSession.DcID; } + fileOffset += FilePartSize; + if (fileSize != 0 && fileOffset >= fileSize) { - // TODO: speed-up download with multiple parallel getFile (share 10-parallel semaphore with upload) - fileBase = await client.Upload_GetFile(fileLocation, fileSize, ChunkSize); + if (await task != ((fileSize - 1) % FilePartSize) + 1) + throw new ApplicationException("Downloaded file size does not match expected file size"); + break; } - catch (RpcException ex) when (ex.Code == 303 && ex.Message.StartsWith("FILE_MIGRATE_")) + + async Task LoadPart(int offset) { - var dcId = int.Parse(ex.Message[13..]); - client = await GetClientForDC(dcId, true); - fileBase = await client.Upload_GetFile(fileLocation, fileSize, ChunkSize); + Upload_FileBase fileBase; + try + { + Console.WriteLine($"LoadPart {offset}"); + fileBase = await client.Upload_GetFile(fileLocation, offset, FilePartSize); + } + catch (RpcException ex) when (ex.Code == 303 && ex.Message.StartsWith("FILE_MIGRATE_")) + { + var dcId = int.Parse(ex.Message[13..]); + client = await GetClientForDC(dcId, true); + fileBase = await client.Upload_GetFile(fileLocation, offset, FilePartSize); + } + catch (RpcException ex) when (ex.Code == 400 && ex.Message == "OFFSET_INVALID") + { + abort = true; + return 0; + } + catch (Exception) + { + abort = true; + throw; + } + finally + { + _parallelTransfers.Release(); + } + if (fileBase is not Upload_File fileData) + throw new ApplicationException("Upload_GetFile returned unsupported " + fileBase.GetType().Name); + if (fileData.bytes.Length != FilePartSize) abort = true; + if (fileData.bytes.Length != 0) + { + fileType = fileData.type; + await writeSem.WaitAsync(); + try + { + if (streamStartPos + offset != outputStream.Position) // if we're about to write out of order + { + await outputStream.FlushAsync(); // async flush, otherwise Seek would do a sync flush + outputStream.Seek(streamStartPos + offset, SeekOrigin.Begin); + } + await outputStream.WriteAsync(fileData.bytes, 0, fileData.bytes.Length); + maxOffsetSeen = Math.Max(maxOffsetSeen, offset + fileData.bytes.Length); + } + catch (Exception) + { + abort = true; + throw; + } + finally + { + writeSem.Release(); + } + } + lock (tasks) tasks.Remove(offset); + return fileData.bytes.Length; } - fileData = fileBase as Upload_File; - if (fileData == null) - throw new ApplicationException("Upload_GetFile returned unsupported " + fileBase.GetType().Name); - await outputStream.WriteAsync(fileData.bytes, 0, fileData.bytes.Length); - fileSize += fileData.bytes.Length; - } while (fileData.bytes.Length == ChunkSize); + } + Task[] remainingTasks; + lock (tasks) remainingTasks = tasks.Values.ToArray(); + await Task.WhenAll(remainingTasks); // wait completion and eventually propagate any task exception await outputStream.FlushAsync(); - return fileData.type; + outputStream.Seek(streamStartPos + maxOffsetSeen, SeekOrigin.Begin); + return fileType; } #endregion