From b6dbf9564f26a58af8bc1edb9652d101f194beb4 Mon Sep 17 00:00:00 2001 From: Wizou <11647984+wiz0u@users.noreply.github.com> Date: Mon, 22 Apr 2024 17:28:44 +0200 Subject: [PATCH] Fix excessive stack usage (#246) --- generator/MTProtoGenerator.cs | 80 +++++++++++++++++------------------ src/TL.Table.cs | 2 + src/TL.Xtended.cs | 5 ++- src/TL.cs | 8 ++-- 4 files changed, 51 insertions(+), 44 deletions(-) diff --git a/generator/MTProtoGenerator.cs b/generator/MTProtoGenerator.cs index 14db231..06ad7d1 100644 --- a/generator/MTProtoGenerator.cs +++ b/generator/MTProtoGenerator.cs @@ -31,10 +31,11 @@ public class MTProtoGenerator : IIncrementalGenerator if (unit.compilation.GetTypeByMetadataName("TL.IObject") is not { } iobject) return; var nullables = LoadNullables(layer); var namespaces = new Dictionary>(); // namespace,class,methods - var makeTL = new StringBuilder(); + var tableTL = new StringBuilder(); var source = new StringBuilder(); source .AppendLine("using System;") + .AppendLine("using System.Collections.Generic;") .AppendLine("using System.ComponentModel;") .AppendLine("using System.IO;") .AppendLine("using System.Linq;") @@ -42,8 +43,8 @@ public class MTProtoGenerator : IIncrementalGenerator .AppendLine() .AppendLine("#pragma warning disable CS0109") .AppendLine(); - makeTL - .AppendLine("\t\tpublic static IObject ReadTL(this BinaryReader reader, uint ctorId = 0) => (ctorId != 0 ? ctorId : reader.ReadUInt32()) switch") + tableTL + .AppendLine("\t\tpublic static readonly Dictionary> Table = new()") .AppendLine("\t\t{"); foreach (var classDecl in unit.classes) @@ -54,7 +55,7 @@ public class MTProtoGenerator : IIncrementalGenerator if (tldef == null) continue; var id = (uint)tldef.ConstructorArguments[0].Value; var inheritBefore = (bool?)tldef.NamedArguments.FirstOrDefault(k => k.Key == "inheritBefore").Value.Value ?? false; - StringBuilder writeTl = new(), ctorTL = new(); + StringBuilder writeTl = new(), readTL = new(); var ns = symbol.BaseType.ContainingNamespace.ToString(); var name = symbol.BaseType.Name; if (ns != "System") @@ -85,18 +86,18 @@ public class MTProtoGenerator : IIncrementalGenerator continue; } if (id == 0x3072CFA1) // GzipPacked - makeTL.AppendLine($"\t\t\t0x{id:X8} => (IObject)reader.ReadTLGzipped(typeof(IObject)),"); + tableTL.AppendLine($"\t\t\t[0x{id:X8}] = reader => (IObject)reader.ReadTLGzipped(typeof(IObject)),"); else if (name != "Null" && (ns != "TL.Methods" || name == "Ping")) - makeTL.AppendLine($"\t\t\t0x{id:X8} => new {(ns == "TL" ? "" : ns + '.')}{name}().ReadTL(reader),"); + tableTL.AppendLine($"\t\t\t[0x{id:X8}] = {(ns == "TL" ? "" : ns + '.')}{name}.ReadTL,"); var override_ = symbol.BaseType == object_ ? "" : "override "; if (name == "Messages_AffectedMessages") override_ = "virtual "; //if (symbol.Constructors[0].IsImplicitlyDeclared) // ctorTL.AppendLine($"\t\tpublic {name}() {{ }}"); if (symbol.IsGenericType) name += ""; - ctorTL - .AppendLine("\t\t[EditorBrowsable(EditorBrowsableState.Never)]") - .AppendLine($"\t\tpublic new {name} ReadTL(BinaryReader reader)") - .AppendLine("\t\t{"); + readTL + .AppendLine($"\t\tpublic static new {name} ReadTL(BinaryReader reader)") + .AppendLine("\t\t{") + .AppendLine($"\t\t\tvar r = new {name}();"); writeTl .AppendLine("\t\t[EditorBrowsable(EditorBrowsableState.Never)]") .AppendLine($"\t\tpublic {override_}void WriteTL(BinaryWriter writer)") @@ -109,88 +110,88 @@ public class MTProtoGenerator : IIncrementalGenerator foreach (var member in members.OfType()) { if (member.DeclaredAccessibility != Accessibility.Public || member.IsStatic) continue; - ctorTL.Append("\t\t\t"); + readTL.Append("\t\t\t"); writeTl.Append("\t\t\t"); var ifFlag = (int?)member.GetAttributes().FirstOrDefault(a => a.AttributeClass == ifFlagAttribute)?.ConstructorArguments[0].Value; if (ifFlag != null) { - var condition = ifFlag < 32 ? $"if (((uint)flags & 0x{1 << ifFlag:X}) != 0) " - : $"if (((uint)flags2 & 0x{1 << (ifFlag - 32):X}) != 0) "; - ctorTL.Append(condition); - writeTl.Append(condition); + readTL.Append(ifFlag < 32 ? $"if (((uint)r.flags & 0x{1 << ifFlag:X}) != 0) " + : $"if (((uint)r.flags2 & 0x{1 << (ifFlag - 32):X}) != 0) "); + writeTl.Append(ifFlag < 32 ? $"if (((uint)flags & 0x{1 << ifFlag:X}) != 0) " + : $"if (((uint)flags2 & 0x{1 << (ifFlag - 32):X}) != 0) "); } string memberType = member.Type.ToString(); switch (memberType) { case "int": - ctorTL.AppendLine($"{member.Name} = reader.ReadInt32();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadInt32();"); writeTl.AppendLine($"writer.Write({member.Name});"); break; case "long": - ctorTL.AppendLine($"{member.Name} = reader.ReadInt64();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadInt64();"); writeTl.AppendLine($"writer.Write({member.Name});"); break; case "double": - ctorTL.AppendLine($"{member.Name} = reader.ReadDouble();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadDouble();"); writeTl.AppendLine($"writer.Write({member.Name});"); break; case "bool": - ctorTL.AppendLine($"{member.Name} = reader.ReadTLBool();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLBool();"); writeTl.AppendLine($"writer.Write({member.Name} ? 0x997275B5 : 0xBC799737);"); break; case "System.DateTime": - ctorTL.AppendLine($"{member.Name} = reader.ReadTLStamp();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLStamp();"); writeTl.AppendLine($"writer.WriteTLStamp({member.Name});"); break; case "string": - ctorTL.AppendLine($"{member.Name} = reader.ReadTLString();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLString();"); writeTl.AppendLine($"writer.WriteTLString({member.Name});"); break; case "byte[]": - ctorTL.AppendLine($"{member.Name} = reader.ReadTLBytes();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLBytes();"); writeTl.AppendLine($"writer.WriteTLBytes({member.Name});"); break; case "TL.Int128": - ctorTL.AppendLine($"{member.Name} = new Int128(reader);"); + readTL.AppendLine($"r.{member.Name} = new Int128(reader);"); writeTl.AppendLine($"writer.Write({member.Name});"); break; case "TL.Int256": - ctorTL.AppendLine($"{member.Name} = new Int256(reader);"); + readTL.AppendLine($"r.{member.Name} = new Int256(reader);"); writeTl.AppendLine($"writer.Write({member.Name});"); break; case "TL._Message[]": - ctorTL.AppendLine($"{member.Name} = reader.ReadTLRawVector<_Message>(0x5BB8E511);"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLRawVector<_Message>(0x5BB8E511);"); writeTl.AppendLine($"writer.WriteTLMessages({member.Name});"); break; case "TL.IObject": case "TL.IMethod": - ctorTL.AppendLine($"{member.Name} = {(memberType == "TL.IObject" ? "" : $"({memberType})")}reader.ReadTL();"); + readTL.AppendLine($"r.{member.Name} = {(memberType == "TL.IObject" ? "" : $"({memberType})")}reader.ReadTLObject();"); writeTl.AppendLine($"{member.Name}.WriteTL(writer);"); break; case "System.Collections.Generic.Dictionary": - ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLDictionary();"); writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());"); break; case "System.Collections.Generic.Dictionary": - ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLDictionary();"); writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());"); break; default: if (member.Type is IArrayTypeSymbol arrayType) { if (name is "FutureSalts") - ctorTL.AppendLine($"{member.Name} = reader.ReadTLRawVector<{memberType.Substring(0, memberType.Length - 2)}>(0x0949D9DC);"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLRawVector<{memberType.Substring(0, memberType.Length - 2)}>(0x0949D9DC);"); else - ctorTL.AppendLine($"{member.Name} = reader.ReadTLVector<{memberType.Substring(0, memberType.Length - 2)}>();"); + readTL.AppendLine($"r.{member.Name} = reader.ReadTLVector<{memberType.Substring(0, memberType.Length - 2)}>();"); writeTl.AppendLine($"writer.WriteTLVector({member.Name});"); } else if (member.Type.BaseType.SpecialType == SpecialType.System_Enum) { - ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadUInt32();"); + readTL.AppendLine($"r.{member.Name} = ({memberType})reader.ReadUInt32();"); writeTl.AppendLine($"writer.Write((uint){member.Name});"); } else if (memberType.StartsWith("TL.")) { - ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadTL();"); + readTL.AppendLine($"r.{member.Name} = ({memberType})reader.ReadTLObject();"); var nullStr = nullables.TryGetValue(memberType, out uint nullCtor) ? $"0x{nullCtor:X8}" : "Layer.NullCtor"; writeTl.AppendLine($"if ({member.Name} != null) {member.Name}.WriteTL(writer); else writer.Write({nullStr});"); } @@ -199,18 +200,17 @@ public class MTProtoGenerator : IIncrementalGenerator break; } } - ctorTL.AppendLine("\t\t\treturn this;"); - ctorTL.AppendLine("\t\t}"); + readTL.AppendLine("\t\t\treturn r;"); + readTL.AppendLine("\t\t}"); writeTl.AppendLine("\t\t}"); - ctorTL.Append(writeTl.ToString()); - classes[name] = ctorTL.ToString(); + readTL.Append(writeTl.ToString()); + classes[name] = readTL.ToString(); } foreach (var nullable in nullables) - makeTL.AppendLine($"\t\t\t0x{nullable.Value:X8} => null,"); - makeTL.AppendLine("\t\t\tvar ctorNb => throw new WTelegram.WTException($\"Cannot find type for ctor #{ctorNb:x}\")"); - makeTL.AppendLine("\t\t};"); - namespaces["TL"]["Layer"] = makeTL.ToString(); + tableTL.AppendLine($"\t\t\t[0x{nullable.Value:X8}] = null,"); + tableTL.AppendLine("\t\t};"); + namespaces["TL"]["Layer"] = tableTL.ToString(); foreach (var namesp in namespaces) { source.Append("namespace ").AppendLine(namesp.Key).Append('{'); diff --git a/src/TL.Table.cs b/src/TL.Table.cs index e4c5faf..0e6ea71 100644 --- a/src/TL.Table.cs +++ b/src/TL.Table.cs @@ -17,6 +17,7 @@ namespace TL internal const uint BadMsgCtor = 0xA7EFF811; internal const uint GZipedCtor = 0x3072CFA1; +#if !MTPG [EditorBrowsable(EditorBrowsableState.Never)] public readonly static Dictionary Table = new() { @@ -1315,6 +1316,7 @@ namespace TL [0xAA48327D] = typeof(Layer8.DecryptedMessageService), [0x1F814F1F] = typeof(Layer8.DecryptedMessage), }; +#endif internal readonly static Dictionary Nullables = new() { diff --git a/src/TL.Xtended.cs b/src/TL.Xtended.cs index 28d04ed..0433a0a 100644 --- a/src/TL.Xtended.cs +++ b/src/TL.Xtended.cs @@ -548,7 +548,10 @@ namespace TL partial class MessageEntity { - public string Type { get { var name = GetType().Name; return name[(name.IndexOf("MessageEntity") + 13)..]; } } + public string Type { + get { var name = GetType().Name; return name[(name.IndexOf("MessageEntity") + 13)..]; } + set { if (value != Type) throw new NotSupportedException("Can't change Type. You need to create a new instance of the right TL.MessageEntity* subclass"); } + } public int Offset { get => offset; set => offset = value; } public int Length { get => length; set => length = value; } } diff --git a/src/TL.cs b/src/TL.cs index 209a642..6884b25 100644 --- a/src/TL.cs +++ b/src/TL.cs @@ -76,10 +76,12 @@ namespace TL public static IObject ReadTLObject(this BinaryReader reader, uint ctorNb = 0) { -#if MTPG - return reader.ReadTL(ctorNb); -#else if (ctorNb == 0) ctorNb = reader.ReadUInt32(); +#if MTPG + if (!Layer.Table.TryGetValue(ctorNb, out var ctor)) + throw new WTelegram.WTException($"Cannot find type for ctor #{ctorNb:x}"); + return ctor?.Invoke(reader); +#else if (ctorNb == Layer.GZipedCtor) return (IObject)reader.ReadTLGzipped(typeof(IObject)); if (!Layer.Table.TryGetValue(ctorNb, out var type)) throw new WTelegram.WTException($"Cannot find type for ctor #{ctorNb:x}");