diff --git a/generator/MTProtoGenerator.cs b/generator/MTProtoGenerator.cs new file mode 100644 index 0000000..4901bd8 --- /dev/null +++ b/generator/MTProtoGenerator.cs @@ -0,0 +1,235 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Linq; +using System.Text; + +#pragma warning disable RS1024 // Symbols should be compared for equality + +namespace TL.Generator; + +[Generator] +public class MTProtoGenerator : IIncrementalGenerator +{ + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var classDeclarations = context.SyntaxProvider.ForAttributeWithMetadataName("TL.TLDefAttribute", + (_, _) => true, (context, _) => (ClassDeclarationSyntax)context.TargetNode); + var source = context.CompilationProvider.Combine(classDeclarations.Collect()); + context.RegisterSourceOutput(source, Execute); + } + + static void Execute(SourceProductionContext context, (Compilation compilation, ImmutableArray classes) unit) + { + var object_ = unit.compilation.GetSpecialType(SpecialType.System_Object); + if (unit.compilation.GetTypeByMetadataName("TL.TLDefAttribute") is not { } tlDefAttribute) return; + if (unit.compilation.GetTypeByMetadataName("TL.IfFlagAttribute") is not { } ifFlagAttribute) return; + if (unit.compilation.GetTypeByMetadataName("TL.Layer") is not { } layer) return; + if (unit.compilation.GetTypeByMetadataName("TL.IObject") is not { } iobject) return; + var nullables = LoadNullables(layer); + var namespaces = new Dictionary>(); // namespace,class,methods + var readTL = new StringBuilder(); + readTL + .AppendLine("\t\tpublic static IObject ReadTL(this BinaryReader reader, uint ctorId = 0) => (ctorId != 0 ? ctorId : reader.ReadUInt32()) switch") + .AppendLine("\t\t{"); + + foreach (var classDecl in unit.classes) + { + var semanticModel = unit.compilation.GetSemanticModel(classDecl.SyntaxTree); + if (semanticModel.GetDeclaredSymbol(classDecl) is not { } symbol) continue; + var tldef = symbol.GetAttributes().FirstOrDefault(a => a.AttributeClass == tlDefAttribute); + if (tldef == null) continue; + var id = (uint)tldef.ConstructorArguments[0].Value; + var inheritBefore = (bool?)tldef.NamedArguments.FirstOrDefault(k => k.Key == "inheritBefore").Value.Value ?? false; + StringBuilder writeTl = new(), ctorTL = new(); + var ns = symbol.BaseType.ContainingNamespace.ToString(); + var name = symbol.BaseType.Name; + if (ns != "System") + { + if (!namespaces.TryGetValue(ns, out var parentClasses)) namespaces[ns] = parentClasses = []; + parentClasses.TryGetValue(name, out var parentMethods); + if (symbol.BaseType.IsAbstract) + { + if (parentMethods == null) + { + writeTl.AppendLine("\t\tpublic abstract void WriteTL(BinaryWriter writer);"); + parentClasses[name] = writeTl.ToString(); + writeTl.Clear(); + } + } + else if (parentMethods?.Contains(" virtual ") == false) + parentClasses[name] = parentMethods.Replace("public void WriteTL(", "public virtual void WriteTL("); + } + ns = symbol.ContainingNamespace.ToString(); + name = symbol.Name; + if (!namespaces.TryGetValue(ns, out var classes)) namespaces[ns] = classes = []; + if (name is "_Message" or "RpcResult" or "MsgCopy") + { + classes[name] = "\t\tpublic void WriteTL(BinaryWriter writer) => throw new NotSupportedException();"; + continue; + } + if (id == 0x3072CFA1) // GzipPacked + readTL.AppendLine($"\t\t\t0x{id:X8} => reader.ReadTLGzipped(),"); + else if (name != "Null" && (ns != "TL.Methods" || name == "Ping")) + readTL.AppendLine($"\t\t\t0x{id:X8} => new {(ns == "TL" ? "" : ns + '.')}{name}(reader),"); + var override_ = symbol.BaseType == object_ ? "" : "override "; + if (name == "Messages_AffectedMessages") override_ = "virtual "; + if (symbol.Constructors[0].IsImplicitlyDeclared) + ctorTL.AppendLine($"\t\tpublic {name}() {{ }}"); + ctorTL + .AppendLine($"\t\tpublic {name}(BinaryReader reader)") + .AppendLine("\t\t{"); + writeTl + .AppendLine($"\t\tpublic {override_}void WriteTL(BinaryWriter writer)") + .AppendLine("\t\t{") + .AppendLine($"\t\t\twriter.Write(0x{id:X8});"); + var members = symbol.GetMembers().ToList(); + for (var parent = symbol.BaseType; parent != object_; parent = parent.BaseType) + if (inheritBefore) members.InsertRange(0, parent.GetMembers()); + else members.AddRange(parent.GetMembers()); + foreach (var member in members.OfType()) + { + if (member.DeclaredAccessibility != Accessibility.Public || member.IsStatic) continue; + ctorTL.Append("\t\t\t"); + writeTl.Append("\t\t\t"); + var ifFlag = (int?)member.GetAttributes().FirstOrDefault(a => a.AttributeClass == ifFlagAttribute)?.ConstructorArguments[0].Value; + if (ifFlag != null) + { + var condition = ifFlag < 32 ? $"if (((uint)flags & 0x{1 << ifFlag:X}) != 0) " + : $"if (((uint)flags2 & 0x{1 << (ifFlag - 32):X}) != 0) "; + ctorTL.Append(condition); + writeTl.Append(condition); + } + string memberType = member.Type.ToString(); + switch (memberType) + { + case "int": + ctorTL.AppendLine($"{member.Name} = reader.ReadInt32();"); + writeTl.AppendLine($"writer.Write({member.Name});"); + break; + case "long": + ctorTL.AppendLine($"{member.Name} = reader.ReadInt64();"); + writeTl.AppendLine($"writer.Write({member.Name});"); + break; + case "double": + ctorTL.AppendLine($"{member.Name} = reader.ReadDouble();"); + writeTl.AppendLine($"writer.Write({member.Name});"); + break; + case "bool": + ctorTL.AppendLine($"{member.Name} = reader.ReadTLBool();"); + writeTl.AppendLine($"writer.Write({member.Name} ? 0x997275B5 : 0xBC799737);"); + break; + case "System.DateTime": + ctorTL.AppendLine($"{member.Name} = reader.ReadTLStamp();"); + writeTl.AppendLine($"writer.WriteTLStamp({member.Name});"); + break; + case "string": + ctorTL.AppendLine($"{member.Name} = reader.ReadTLString();"); + writeTl.AppendLine($"writer.WriteTLString({member.Name});"); + break; + case "byte[]": + ctorTL.AppendLine($"{member.Name} = reader.ReadTLBytes();"); + writeTl.AppendLine($"writer.WriteTLBytes({member.Name});"); + break; + case "TL.Int128": + ctorTL.AppendLine($"{member.Name} = new Int128(reader);"); + writeTl.AppendLine($"writer.Write({member.Name});"); + break; + case "TL.Int256": + ctorTL.AppendLine($"{member.Name} = new Int256(reader);"); + writeTl.AppendLine($"writer.Write({member.Name});"); + break; + case "TL._Message[]": + ctorTL.AppendLine($"throw new NotSupportedException();"); + writeTl.AppendLine($"writer.WriteTLMessages({member.Name});"); + break; + case "TL.IObject": case "TL.IMethod": + ctorTL.AppendLine($"{member.Name} = {(memberType == "TL.IObject" ? "" : $"({memberType})")}reader.ReadTL();"); + writeTl.AppendLine($"{member.Name}.WriteTL(writer);"); + break; + case "System.Collections.Generic.Dictionary": + ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary();"); + writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());"); + break; + case "System.Collections.Generic.Dictionary": + ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary();"); + writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());"); + break; + default: + if (member.Type is IArrayTypeSymbol arrayType) + { + if (name is "FutureSalts") + ctorTL.AppendLine($"{member.Name} = reader.ReadTLRawVector<{memberType.Substring(0, memberType.Length - 2)}>(0x0949D9DC);"); + else + ctorTL.AppendLine($"{member.Name} = reader.ReadTLVector<{memberType.Substring(0, memberType.Length - 2)}>();"); + writeTl.AppendLine($"writer.WriteTLVector({member.Name});"); + } + else if (member.Type.BaseType.SpecialType == SpecialType.System_Enum) + { + ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadUInt32();"); + writeTl.AppendLine($"writer.Write((uint){member.Name});"); + } + else if (memberType.StartsWith("TL.")) + { + ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadTL();"); + var nullStr = nullables.TryGetValue(memberType, out uint nullCtor) ? $"0x{nullCtor:X8}" : "Layer.NullCtor"; + writeTl.AppendLine($"if ({member.Name} != null) {member.Name}.WriteTL(writer); else writer.Write({nullStr});"); + } + else + writeTl.AppendLine($"Cannot serialize {memberType}"); + break; + } + } + ctorTL.AppendLine("\t\t}"); + writeTl.AppendLine("\t\t}"); + ctorTL.Append(writeTl.ToString()); + if (symbol.IsGenericType) name += ""; + classes[name] = ctorTL.ToString(); + } + + var source = new StringBuilder(); + source + .AppendLine("using System;") + .AppendLine("using System.IO;") + .AppendLine("using System.Linq;") + .AppendLine("using TL;") + .AppendLine(); + foreach (var nullable in nullables) + readTL.AppendLine($"\t\t\t0x{nullable.Value:X8} => null,"); + readTL.AppendLine("\t\t\tvar ctorNb => throw new Exception($\"Cannot find type for ctor #{ctorNb:x}\")"); + readTL.AppendLine("\t\t};"); + namespaces["TL"]["Layer"] = readTL.ToString(); + foreach (var namesp in namespaces) + { + source.Append("namespace ").AppendLine(namesp.Key).Append('{'); + foreach (var method in namesp.Value) + source.AppendLine().Append("\tpartial class ").AppendLine(method.Key).AppendLine("\t{").Append(method.Value).AppendLine("\t}"); + source.AppendLine("}").AppendLine(); + } + string text = source.ToString(); + Debug.Write(text); + context.AddSource("TL.Generated.cs", text); + } + + private static Dictionary LoadNullables(INamedTypeSymbol layer) + { + var nullables = layer.GetMembers("Nullables").Single() as IFieldSymbol; + var initializer = nullables.DeclaringSyntaxReferences[0].GetSyntax().ToString(); + var table = new Dictionary(); + foreach (var line in initializer.Split('\n')) + { + int index = line.IndexOf("[typeof("); + if (index == -1) continue; + int index2 = line.IndexOf(')', index += 8); + string className = "TL." + line.Substring(index, index2 - index); + index = line.IndexOf("= 0x", index2); + if (index == -1) continue; + index2 = line.IndexOf(',', index += 4); + table[className] = uint.Parse(line.Substring(index, index2 - index), System.Globalization.NumberStyles.HexNumber); + } + return table; + } +} diff --git a/generator/MTProtoGenerator.csproj b/generator/MTProtoGenerator.csproj new file mode 100644 index 0000000..881f2da --- /dev/null +++ b/generator/MTProtoGenerator.csproj @@ -0,0 +1,20 @@ + + + netstandard2.0 + true + true + true + True + latest + + + + + + + + + + \ No newline at end of file diff --git a/src/Client.cs b/src/Client.cs index e5da354..af57f0b 100644 --- a/src/Client.cs +++ b/src/Client.cs @@ -570,12 +570,11 @@ namespace WTelegram if (peek == Layer.RpcErrorCtor) result = reader.ReadTLObject(Layer.RpcErrorCtor); else if (peek == Layer.GZipedCtor) - using (var gzipReader = new BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress))) - result = gzipReader.ReadTLValue(rpc.type); + result = reader.ReadTLGzipped(); else { reader.BaseStream.Position -= 4; - result = reader.ReadTLValue(rpc.type); + result = reader.ReadTLVector(rpc.type); } } if (rpc.type.IsEnum) result = Enum.ToObject(rpc.type, result); diff --git a/src/Helpers.cs b/src/Helpers.cs index e76c116..e716078 100644 --- a/src/Helpers.cs +++ b/src/Helpers.cs @@ -4,12 +4,11 @@ using System.IO; using System.Numerics; using System.Reflection; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; #if NET8_0_OR_GREATER -using System.Text.Json.Serialization; -using System.Text.Json.Serialization.Metadata; [JsonSerializable(typeof(WTelegram.Session))] internal partial class WTelegramContext : JsonSerializerContext { } #endif @@ -24,9 +23,9 @@ namespace WTelegram /// For serializing indented Json with fields included public static readonly JsonSerializerOptions JsonOptions = new() { IncludeFields = true, WriteIndented = true, #if NET8_0_OR_GREATER - TypeInfoResolver = JsonSerializer.IsReflectionEnabledByDefault ? new DefaultJsonTypeInfoResolver() : WTelegramContext.Default, + TypeInfoResolver = JsonSerializer.IsReflectionEnabledByDefault ? null : WTelegramContext.Default, #endif - IgnoreReadOnlyProperties = true, DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull }; + IgnoreReadOnlyProperties = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }; private static readonly ConsoleColor[] LogLevelToColor = [ ConsoleColor.DarkGray, ConsoleColor.DarkCyan, ConsoleColor.Cyan, ConsoleColor.Yellow, ConsoleColor.Red, ConsoleColor.Magenta, ConsoleColor.DarkBlue ]; diff --git a/src/TL.Extensions.cs b/src/TL.Extensions.cs index 428f243..8aa400a 100644 --- a/src/TL.Extensions.cs +++ b/src/TL.Extensions.cs @@ -88,6 +88,9 @@ namespace TL } return null; } +#if MTPG + public override void WriteTL(System.IO.BinaryWriter writer) => throw new NotImplementedException(); +#endif } /// Accumulate users/chats found in this structure in your dictionaries, ignoring Min constructors when the full object is already stored diff --git a/src/TL.cs b/src/TL.cs index 6d4541a..2eee871 100644 --- a/src/TL.cs +++ b/src/TL.cs @@ -11,7 +11,11 @@ using System.Text; namespace TL { +#if MTPG + public interface IObject { void WriteTL(BinaryWriter writer); } +#else public interface IObject { } +#endif public interface IMethod : IObject { } public interface IPeerResolver { IPeerInfo UserOrChat(Peer peer); } @@ -39,6 +43,7 @@ namespace TL public sealed partial class ReactorError : IObject { public Exception Exception; + public void WriteTL(BinaryWriter writer) => throw new NotSupportedException(); } public static class Serialization @@ -46,6 +51,9 @@ namespace TL public static void WriteTLObject(this BinaryWriter writer, T obj) where T : IObject { if (obj == null) { writer.WriteTLNull(typeof(T)); return; } +#if MTPG + obj.WriteTL(writer); +#else var type = obj.GetType(); var tlDef = type.GetCustomAttribute(); var ctorNb = tlDef.CtorNb; @@ -63,14 +71,16 @@ namespace TL if (field.Name == "flags") flags = (uint)value; else if (field.Name == "flags2") flags |= (ulong)(uint)value << 32; } +#endif } public static IObject ReadTLObject(this BinaryReader reader, uint ctorNb = 0) { +#if MTPG + return reader.ReadTL(ctorNb); +#else if (ctorNb == 0) ctorNb = reader.ReadUInt32(); - if (ctorNb == Layer.GZipedCtor) - using (var gzipReader = new BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress))) - return ReadTLObject(gzipReader); + if (ctorNb == Layer.GZipedCtor) return reader.ReadTLGzipped(); if (!Layer.Table.TryGetValue(ctorNb, out var type)) throw new WTelegram.WTException($"Cannot find type for ctor #{ctorNb:x}"); if (type == null) return null; // nullable ctor (class meaning is associated with null) @@ -90,6 +100,7 @@ namespace TL else if (field.Name == "flags2") flags |= (ulong)(uint)value << 32; } return (IObject)obj; +#endif } internal static void WriteTLValue(this BinaryWriter writer, object value, Type valueType) @@ -178,17 +189,6 @@ namespace TL } } - internal static void WriteTLVector(this BinaryWriter writer, Array array) - { - writer.Write(Layer.VectorCtor); - if (array == null) { writer.Write(0); return; } - int count = array.Length; - writer.Write(count); - var elementType = array.GetType().GetElementType(); - for (int i = 0; i < count; i++) - writer.WriteTLValue(array.GetValue(i), elementType); - } - internal static void WriteTLMessages(this BinaryWriter writer, _Message[] messages) { writer.Write(messages.Length); @@ -209,6 +209,38 @@ namespace TL } } + internal static void WriteTLVector(this BinaryWriter writer, Array array) + { + writer.Write(Layer.VectorCtor); + if (array == null) { writer.Write(0); return; } + int count = array.Length; + writer.Write(count); + var elementType = array.GetType().GetElementType(); + for (int i = 0; i < count; i++) + writer.WriteTLValue(array.GetValue(i), elementType); + } + + internal static T[] ReadTLRawVector(this BinaryReader reader, uint ctorNb) + { + int count = reader.ReadInt32(); + var array = new T[count]; + for (int i = 0; i < count; i++) + array[i] = (T)reader.ReadTLObject(ctorNb); + return array; + } + + internal static T[] ReadTLVector(this BinaryReader reader) + { + var elementType = typeof(T); + if (reader.ReadUInt32() is not Layer.VectorCtor and uint ctorNb) + throw new WTelegram.WTException($"Cannot deserialize {elementType.Name}[] with ctor #{ctorNb:x}"); + int count = reader.ReadInt32(); + var array = new T[count]; + for (int i = 0; i < count; i++) + array[i] = (T)reader.ReadTLValue(elementType); + return array; + } + internal static Array ReadTLVector(this BinaryReader reader, Type type) { var elementType = type.GetElementType(); @@ -240,14 +272,13 @@ namespace TL internal static Dictionary ReadTLDictionary(this BinaryReader reader) where T : class, IPeerInfo { uint ctorNb = reader.ReadUInt32(); - var elementType = typeof(T); if (ctorNb != Layer.VectorCtor) - throw new WTelegram.WTException($"Cannot deserialize Vector<{elementType.Name}> with ctor #{ctorNb:x}"); + throw new WTelegram.WTException($"Cannot deserialize Vector<{typeof(T).Name}> with ctor #{ctorNb:x}"); int count = reader.ReadInt32(); var dict = new Dictionary(count); for (int i = 0; i < count; i++) { - var value = (T)reader.ReadTLValue(elementType); + var value = (T)reader.ReadTLObject(); dict[value.ID] = value is UserEmpty ? null : value; } return dict; @@ -317,6 +348,19 @@ namespace TL writer.Write(0); // null arrays/strings are serialized as empty } + internal static IObject ReadTLGzipped(this BinaryReader reader) + { + using var gzipReader = new BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress)); + return ReadTLObject(gzipReader); + } + + internal static bool ReadTLBool(this BinaryReader reader) => reader.ReadUInt32() switch + { + 0x997275b5 => true, + 0xbc799737 => false, + var value => throw new WTelegram.WTException($"Invalid boolean value #{value:x}") + }; + #if DEBUG private static void ShouldntBeHere() => System.Diagnostics.Debugger.Break(); #else @@ -356,6 +400,9 @@ namespace TL { public Messages_AffectedMessages affected; public override (long, int, int) GetMBox() => (0, affected.pts, affected.pts_count); +#if MTPG + public override void WriteTL(BinaryWriter writer) => throw new NotSupportedException(); +#endif } // Below TL types are commented "parsed manually" from https://github.com/telegramdesktop/tdesktop/blob/dev/Telegram/Resources/tl/mtproto.tl diff --git a/src/WTelegramClient.csproj b/src/WTelegramClient.csproj index d2b078c..73e5418 100644 --- a/src/WTelegramClient.csproj +++ b/src/WTelegramClient.csproj @@ -25,7 +25,7 @@ README.md $(ReleaseNotes.Replace("|", "%0D%0A").Replace(" - ","%0D%0A- ").Replace(" ", "%0D%0A%0D%0A")) 0419;1573;1591;NETSDK1138 - TRACE;OBFUSCATION + TRACE;OBFUSCATION;MTPG @@ -42,6 +42,10 @@ --> + + + +