Fix excessive stack usage (#246)

This commit is contained in:
Wizou 2024-04-22 17:28:44 +02:00
parent 8228fede0f
commit b6dbf9564f
4 changed files with 51 additions and 44 deletions

View file

@ -31,10 +31,11 @@ public class MTProtoGenerator : IIncrementalGenerator
if (unit.compilation.GetTypeByMetadataName("TL.IObject") is not { } iobject) return;
var nullables = LoadNullables(layer);
var namespaces = new Dictionary<string, Dictionary<string, string>>(); // namespace,class,methods
var makeTL = new StringBuilder();
var tableTL = new StringBuilder();
var source = new StringBuilder();
source
.AppendLine("using System;")
.AppendLine("using System.Collections.Generic;")
.AppendLine("using System.ComponentModel;")
.AppendLine("using System.IO;")
.AppendLine("using System.Linq;")
@ -42,8 +43,8 @@ public class MTProtoGenerator : IIncrementalGenerator
.AppendLine()
.AppendLine("#pragma warning disable CS0109")
.AppendLine();
makeTL
.AppendLine("\t\tpublic static IObject ReadTL(this BinaryReader reader, uint ctorId = 0) => (ctorId != 0 ? ctorId : reader.ReadUInt32()) switch")
tableTL
.AppendLine("\t\tpublic static readonly Dictionary<uint, Func<BinaryReader, IObject>> Table = new()")
.AppendLine("\t\t{");
foreach (var classDecl in unit.classes)
@ -54,7 +55,7 @@ public class MTProtoGenerator : IIncrementalGenerator
if (tldef == null) continue;
var id = (uint)tldef.ConstructorArguments[0].Value;
var inheritBefore = (bool?)tldef.NamedArguments.FirstOrDefault(k => k.Key == "inheritBefore").Value.Value ?? false;
StringBuilder writeTl = new(), ctorTL = new();
StringBuilder writeTl = new(), readTL = new();
var ns = symbol.BaseType.ContainingNamespace.ToString();
var name = symbol.BaseType.Name;
if (ns != "System")
@ -85,18 +86,18 @@ public class MTProtoGenerator : IIncrementalGenerator
continue;
}
if (id == 0x3072CFA1) // GzipPacked
makeTL.AppendLine($"\t\t\t0x{id:X8} => (IObject)reader.ReadTLGzipped(typeof(IObject)),");
tableTL.AppendLine($"\t\t\t[0x{id:X8}] = reader => (IObject)reader.ReadTLGzipped(typeof(IObject)),");
else if (name != "Null" && (ns != "TL.Methods" || name == "Ping"))
makeTL.AppendLine($"\t\t\t0x{id:X8} => new {(ns == "TL" ? "" : ns + '.')}{name}().ReadTL(reader),");
tableTL.AppendLine($"\t\t\t[0x{id:X8}] = {(ns == "TL" ? "" : ns + '.')}{name}.ReadTL,");
var override_ = symbol.BaseType == object_ ? "" : "override ";
if (name == "Messages_AffectedMessages") override_ = "virtual ";
//if (symbol.Constructors[0].IsImplicitlyDeclared)
// ctorTL.AppendLine($"\t\tpublic {name}() {{ }}");
if (symbol.IsGenericType) name += "<X>";
ctorTL
.AppendLine("\t\t[EditorBrowsable(EditorBrowsableState.Never)]")
.AppendLine($"\t\tpublic new {name} ReadTL(BinaryReader reader)")
.AppendLine("\t\t{");
readTL
.AppendLine($"\t\tpublic static new {name} ReadTL(BinaryReader reader)")
.AppendLine("\t\t{")
.AppendLine($"\t\t\tvar r = new {name}();");
writeTl
.AppendLine("\t\t[EditorBrowsable(EditorBrowsableState.Never)]")
.AppendLine($"\t\tpublic {override_}void WriteTL(BinaryWriter writer)")
@ -109,88 +110,88 @@ public class MTProtoGenerator : IIncrementalGenerator
foreach (var member in members.OfType<IFieldSymbol>())
{
if (member.DeclaredAccessibility != Accessibility.Public || member.IsStatic) continue;
ctorTL.Append("\t\t\t");
readTL.Append("\t\t\t");
writeTl.Append("\t\t\t");
var ifFlag = (int?)member.GetAttributes().FirstOrDefault(a => a.AttributeClass == ifFlagAttribute)?.ConstructorArguments[0].Value;
if (ifFlag != null)
{
var condition = ifFlag < 32 ? $"if (((uint)flags & 0x{1 << ifFlag:X}) != 0) "
: $"if (((uint)flags2 & 0x{1 << (ifFlag - 32):X}) != 0) ";
ctorTL.Append(condition);
writeTl.Append(condition);
readTL.Append(ifFlag < 32 ? $"if (((uint)r.flags & 0x{1 << ifFlag:X}) != 0) "
: $"if (((uint)r.flags2 & 0x{1 << (ifFlag - 32):X}) != 0) ");
writeTl.Append(ifFlag < 32 ? $"if (((uint)flags & 0x{1 << ifFlag:X}) != 0) "
: $"if (((uint)flags2 & 0x{1 << (ifFlag - 32):X}) != 0) ");
}
string memberType = member.Type.ToString();
switch (memberType)
{
case "int":
ctorTL.AppendLine($"{member.Name} = reader.ReadInt32();");
readTL.AppendLine($"r.{member.Name} = reader.ReadInt32();");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "long":
ctorTL.AppendLine($"{member.Name} = reader.ReadInt64();");
readTL.AppendLine($"r.{member.Name} = reader.ReadInt64();");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "double":
ctorTL.AppendLine($"{member.Name} = reader.ReadDouble();");
readTL.AppendLine($"r.{member.Name} = reader.ReadDouble();");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "bool":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLBool();");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLBool();");
writeTl.AppendLine($"writer.Write({member.Name} ? 0x997275B5 : 0xBC799737);");
break;
case "System.DateTime":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLStamp();");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLStamp();");
writeTl.AppendLine($"writer.WriteTLStamp({member.Name});");
break;
case "string":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLString();");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLString();");
writeTl.AppendLine($"writer.WriteTLString({member.Name});");
break;
case "byte[]":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLBytes();");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLBytes();");
writeTl.AppendLine($"writer.WriteTLBytes({member.Name});");
break;
case "TL.Int128":
ctorTL.AppendLine($"{member.Name} = new Int128(reader);");
readTL.AppendLine($"r.{member.Name} = new Int128(reader);");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "TL.Int256":
ctorTL.AppendLine($"{member.Name} = new Int256(reader);");
readTL.AppendLine($"r.{member.Name} = new Int256(reader);");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "TL._Message[]":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLRawVector<_Message>(0x5BB8E511);");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLRawVector<_Message>(0x5BB8E511);");
writeTl.AppendLine($"writer.WriteTLMessages({member.Name});");
break;
case "TL.IObject": case "TL.IMethod<X>":
ctorTL.AppendLine($"{member.Name} = {(memberType == "TL.IObject" ? "" : $"({memberType})")}reader.ReadTL();");
readTL.AppendLine($"r.{member.Name} = {(memberType == "TL.IObject" ? "" : $"({memberType})")}reader.ReadTLObject();");
writeTl.AppendLine($"{member.Name}.WriteTL(writer);");
break;
case "System.Collections.Generic.Dictionary<long, TL.User>":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary<User>();");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLDictionary<User>();");
writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());");
break;
case "System.Collections.Generic.Dictionary<long, TL.ChatBase>":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary<ChatBase>();");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLDictionary<ChatBase>();");
writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());");
break;
default:
if (member.Type is IArrayTypeSymbol arrayType)
{
if (name is "FutureSalts")
ctorTL.AppendLine($"{member.Name} = reader.ReadTLRawVector<{memberType.Substring(0, memberType.Length - 2)}>(0x0949D9DC);");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLRawVector<{memberType.Substring(0, memberType.Length - 2)}>(0x0949D9DC);");
else
ctorTL.AppendLine($"{member.Name} = reader.ReadTLVector<{memberType.Substring(0, memberType.Length - 2)}>();");
readTL.AppendLine($"r.{member.Name} = reader.ReadTLVector<{memberType.Substring(0, memberType.Length - 2)}>();");
writeTl.AppendLine($"writer.WriteTLVector({member.Name});");
}
else if (member.Type.BaseType.SpecialType == SpecialType.System_Enum)
{
ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadUInt32();");
readTL.AppendLine($"r.{member.Name} = ({memberType})reader.ReadUInt32();");
writeTl.AppendLine($"writer.Write((uint){member.Name});");
}
else if (memberType.StartsWith("TL."))
{
ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadTL();");
readTL.AppendLine($"r.{member.Name} = ({memberType})reader.ReadTLObject();");
var nullStr = nullables.TryGetValue(memberType, out uint nullCtor) ? $"0x{nullCtor:X8}" : "Layer.NullCtor";
writeTl.AppendLine($"if ({member.Name} != null) {member.Name}.WriteTL(writer); else writer.Write({nullStr});");
}
@ -199,18 +200,17 @@ public class MTProtoGenerator : IIncrementalGenerator
break;
}
}
ctorTL.AppendLine("\t\t\treturn this;");
ctorTL.AppendLine("\t\t}");
readTL.AppendLine("\t\t\treturn r;");
readTL.AppendLine("\t\t}");
writeTl.AppendLine("\t\t}");
ctorTL.Append(writeTl.ToString());
classes[name] = ctorTL.ToString();
readTL.Append(writeTl.ToString());
classes[name] = readTL.ToString();
}
foreach (var nullable in nullables)
makeTL.AppendLine($"\t\t\t0x{nullable.Value:X8} => null,");
makeTL.AppendLine("\t\t\tvar ctorNb => throw new WTelegram.WTException($\"Cannot find type for ctor #{ctorNb:x}\")");
makeTL.AppendLine("\t\t};");
namespaces["TL"]["Layer"] = makeTL.ToString();
tableTL.AppendLine($"\t\t\t[0x{nullable.Value:X8}] = null,");
tableTL.AppendLine("\t\t};");
namespaces["TL"]["Layer"] = tableTL.ToString();
foreach (var namesp in namespaces)
{
source.Append("namespace ").AppendLine(namesp.Key).Append('{');

View file

@ -17,6 +17,7 @@ namespace TL
internal const uint BadMsgCtor = 0xA7EFF811;
internal const uint GZipedCtor = 0x3072CFA1;
#if !MTPG
[EditorBrowsable(EditorBrowsableState.Never)]
public readonly static Dictionary<uint, Type> Table = new()
{
@ -1315,6 +1316,7 @@ namespace TL
[0xAA48327D] = typeof(Layer8.DecryptedMessageService),
[0x1F814F1F] = typeof(Layer8.DecryptedMessage),
};
#endif
internal readonly static Dictionary<Type, uint> Nullables = new()
{

View file

@ -548,7 +548,10 @@ namespace TL
partial class MessageEntity
{
public string Type { get { var name = GetType().Name; return name[(name.IndexOf("MessageEntity") + 13)..]; } }
public string Type {
get { var name = GetType().Name; return name[(name.IndexOf("MessageEntity") + 13)..]; }
set { if (value != Type) throw new NotSupportedException("Can't change Type. You need to create a new instance of the right TL.MessageEntity* subclass"); }
}
public int Offset { get => offset; set => offset = value; }
public int Length { get => length; set => length = value; }
}

View file

@ -76,10 +76,12 @@ namespace TL
public static IObject ReadTLObject(this BinaryReader reader, uint ctorNb = 0)
{
#if MTPG
return reader.ReadTL(ctorNb);
#else
if (ctorNb == 0) ctorNb = reader.ReadUInt32();
#if MTPG
if (!Layer.Table.TryGetValue(ctorNb, out var ctor))
throw new WTelegram.WTException($"Cannot find type for ctor #{ctorNb:x}");
return ctor?.Invoke(reader);
#else
if (ctorNb == Layer.GZipedCtor) return (IObject)reader.ReadTLGzipped(typeof(IObject));
if (!Layer.Table.TryGetValue(ctorNb, out var type))
throw new WTelegram.WTException($"Cannot find type for ctor #{ctorNb:x}");