Using a source generator to make the library compatible with NativeAOT trimming.

This commit is contained in:
Wizou 2024-03-28 12:13:56 +01:00
parent 55feb857d7
commit 3918e68945
7 changed files with 332 additions and 25 deletions

View file

@ -0,0 +1,235 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
#pragma warning disable RS1024 // Symbols should be compared for equality
namespace TL.Generator;
[Generator]
public class MTProtoGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var classDeclarations = context.SyntaxProvider.ForAttributeWithMetadataName("TL.TLDefAttribute",
(_, _) => true, (context, _) => (ClassDeclarationSyntax)context.TargetNode);
var source = context.CompilationProvider.Combine(classDeclarations.Collect());
context.RegisterSourceOutput(source, Execute);
}
static void Execute(SourceProductionContext context, (Compilation compilation, ImmutableArray<ClassDeclarationSyntax> classes) unit)
{
var object_ = unit.compilation.GetSpecialType(SpecialType.System_Object);
if (unit.compilation.GetTypeByMetadataName("TL.TLDefAttribute") is not { } tlDefAttribute) return;
if (unit.compilation.GetTypeByMetadataName("TL.IfFlagAttribute") is not { } ifFlagAttribute) return;
if (unit.compilation.GetTypeByMetadataName("TL.Layer") is not { } layer) return;
if (unit.compilation.GetTypeByMetadataName("TL.IObject") is not { } iobject) return;
var nullables = LoadNullables(layer);
var namespaces = new Dictionary<string, Dictionary<string, string>>(); // namespace,class,methods
var readTL = new StringBuilder();
readTL
.AppendLine("\t\tpublic static IObject ReadTL(this BinaryReader reader, uint ctorId = 0) => (ctorId != 0 ? ctorId : reader.ReadUInt32()) switch")
.AppendLine("\t\t{");
foreach (var classDecl in unit.classes)
{
var semanticModel = unit.compilation.GetSemanticModel(classDecl.SyntaxTree);
if (semanticModel.GetDeclaredSymbol(classDecl) is not { } symbol) continue;
var tldef = symbol.GetAttributes().FirstOrDefault(a => a.AttributeClass == tlDefAttribute);
if (tldef == null) continue;
var id = (uint)tldef.ConstructorArguments[0].Value;
var inheritBefore = (bool?)tldef.NamedArguments.FirstOrDefault(k => k.Key == "inheritBefore").Value.Value ?? false;
StringBuilder writeTl = new(), ctorTL = new();
var ns = symbol.BaseType.ContainingNamespace.ToString();
var name = symbol.BaseType.Name;
if (ns != "System")
{
if (!namespaces.TryGetValue(ns, out var parentClasses)) namespaces[ns] = parentClasses = [];
parentClasses.TryGetValue(name, out var parentMethods);
if (symbol.BaseType.IsAbstract)
{
if (parentMethods == null)
{
writeTl.AppendLine("\t\tpublic abstract void WriteTL(BinaryWriter writer);");
parentClasses[name] = writeTl.ToString();
writeTl.Clear();
}
}
else if (parentMethods?.Contains(" virtual ") == false)
parentClasses[name] = parentMethods.Replace("public void WriteTL(", "public virtual void WriteTL(");
}
ns = symbol.ContainingNamespace.ToString();
name = symbol.Name;
if (!namespaces.TryGetValue(ns, out var classes)) namespaces[ns] = classes = [];
if (name is "_Message" or "RpcResult" or "MsgCopy")
{
classes[name] = "\t\tpublic void WriteTL(BinaryWriter writer) => throw new NotSupportedException();";
continue;
}
if (id == 0x3072CFA1) // GzipPacked
readTL.AppendLine($"\t\t\t0x{id:X8} => reader.ReadTLGzipped(),");
else if (name != "Null" && (ns != "TL.Methods" || name == "Ping"))
readTL.AppendLine($"\t\t\t0x{id:X8} => new {(ns == "TL" ? "" : ns + '.')}{name}(reader),");
var override_ = symbol.BaseType == object_ ? "" : "override ";
if (name == "Messages_AffectedMessages") override_ = "virtual ";
if (symbol.Constructors[0].IsImplicitlyDeclared)
ctorTL.AppendLine($"\t\tpublic {name}() {{ }}");
ctorTL
.AppendLine($"\t\tpublic {name}(BinaryReader reader)")
.AppendLine("\t\t{");
writeTl
.AppendLine($"\t\tpublic {override_}void WriteTL(BinaryWriter writer)")
.AppendLine("\t\t{")
.AppendLine($"\t\t\twriter.Write(0x{id:X8});");
var members = symbol.GetMembers().ToList();
for (var parent = symbol.BaseType; parent != object_; parent = parent.BaseType)
if (inheritBefore) members.InsertRange(0, parent.GetMembers());
else members.AddRange(parent.GetMembers());
foreach (var member in members.OfType<IFieldSymbol>())
{
if (member.DeclaredAccessibility != Accessibility.Public || member.IsStatic) continue;
ctorTL.Append("\t\t\t");
writeTl.Append("\t\t\t");
var ifFlag = (int?)member.GetAttributes().FirstOrDefault(a => a.AttributeClass == ifFlagAttribute)?.ConstructorArguments[0].Value;
if (ifFlag != null)
{
var condition = ifFlag < 32 ? $"if (((uint)flags & 0x{1 << ifFlag:X}) != 0) "
: $"if (((uint)flags2 & 0x{1 << (ifFlag - 32):X}) != 0) ";
ctorTL.Append(condition);
writeTl.Append(condition);
}
string memberType = member.Type.ToString();
switch (memberType)
{
case "int":
ctorTL.AppendLine($"{member.Name} = reader.ReadInt32();");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "long":
ctorTL.AppendLine($"{member.Name} = reader.ReadInt64();");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "double":
ctorTL.AppendLine($"{member.Name} = reader.ReadDouble();");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "bool":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLBool();");
writeTl.AppendLine($"writer.Write({member.Name} ? 0x997275B5 : 0xBC799737);");
break;
case "System.DateTime":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLStamp();");
writeTl.AppendLine($"writer.WriteTLStamp({member.Name});");
break;
case "string":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLString();");
writeTl.AppendLine($"writer.WriteTLString({member.Name});");
break;
case "byte[]":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLBytes();");
writeTl.AppendLine($"writer.WriteTLBytes({member.Name});");
break;
case "TL.Int128":
ctorTL.AppendLine($"{member.Name} = new Int128(reader);");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "TL.Int256":
ctorTL.AppendLine($"{member.Name} = new Int256(reader);");
writeTl.AppendLine($"writer.Write({member.Name});");
break;
case "TL._Message[]":
ctorTL.AppendLine($"throw new NotSupportedException();");
writeTl.AppendLine($"writer.WriteTLMessages({member.Name});");
break;
case "TL.IObject": case "TL.IMethod<X>":
ctorTL.AppendLine($"{member.Name} = {(memberType == "TL.IObject" ? "" : $"({memberType})")}reader.ReadTL();");
writeTl.AppendLine($"{member.Name}.WriteTL(writer);");
break;
case "System.Collections.Generic.Dictionary<long, TL.User>":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary<User>();");
writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());");
break;
case "System.Collections.Generic.Dictionary<long, TL.ChatBase>":
ctorTL.AppendLine($"{member.Name} = reader.ReadTLDictionary<ChatBase>();");
writeTl.AppendLine($"writer.WriteTLVector({member.Name}.Values.ToArray());");
break;
default:
if (member.Type is IArrayTypeSymbol arrayType)
{
if (name is "FutureSalts")
ctorTL.AppendLine($"{member.Name} = reader.ReadTLRawVector<{memberType.Substring(0, memberType.Length - 2)}>(0x0949D9DC);");
else
ctorTL.AppendLine($"{member.Name} = reader.ReadTLVector<{memberType.Substring(0, memberType.Length - 2)}>();");
writeTl.AppendLine($"writer.WriteTLVector({member.Name});");
}
else if (member.Type.BaseType.SpecialType == SpecialType.System_Enum)
{
ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadUInt32();");
writeTl.AppendLine($"writer.Write((uint){member.Name});");
}
else if (memberType.StartsWith("TL."))
{
ctorTL.AppendLine($"{member.Name} = ({memberType})reader.ReadTL();");
var nullStr = nullables.TryGetValue(memberType, out uint nullCtor) ? $"0x{nullCtor:X8}" : "Layer.NullCtor";
writeTl.AppendLine($"if ({member.Name} != null) {member.Name}.WriteTL(writer); else writer.Write({nullStr});");
}
else
writeTl.AppendLine($"Cannot serialize {memberType}");
break;
}
}
ctorTL.AppendLine("\t\t}");
writeTl.AppendLine("\t\t}");
ctorTL.Append(writeTl.ToString());
if (symbol.IsGenericType) name += "<X>";
classes[name] = ctorTL.ToString();
}
var source = new StringBuilder();
source
.AppendLine("using System;")
.AppendLine("using System.IO;")
.AppendLine("using System.Linq;")
.AppendLine("using TL;")
.AppendLine();
foreach (var nullable in nullables)
readTL.AppendLine($"\t\t\t0x{nullable.Value:X8} => null,");
readTL.AppendLine("\t\t\tvar ctorNb => throw new Exception($\"Cannot find type for ctor #{ctorNb:x}\")");
readTL.AppendLine("\t\t};");
namespaces["TL"]["Layer"] = readTL.ToString();
foreach (var namesp in namespaces)
{
source.Append("namespace ").AppendLine(namesp.Key).Append('{');
foreach (var method in namesp.Value)
source.AppendLine().Append("\tpartial class ").AppendLine(method.Key).AppendLine("\t{").Append(method.Value).AppendLine("\t}");
source.AppendLine("}").AppendLine();
}
string text = source.ToString();
Debug.Write(text);
context.AddSource("TL.Generated.cs", text);
}
private static Dictionary<string, uint> LoadNullables(INamedTypeSymbol layer)
{
var nullables = layer.GetMembers("Nullables").Single() as IFieldSymbol;
var initializer = nullables.DeclaringSyntaxReferences[0].GetSyntax().ToString();
var table = new Dictionary<string, uint>();
foreach (var line in initializer.Split('\n'))
{
int index = line.IndexOf("[typeof(");
if (index == -1) continue;
int index2 = line.IndexOf(')', index += 8);
string className = "TL." + line.Substring(index, index2 - index);
index = line.IndexOf("= 0x", index2);
if (index == -1) continue;
index2 = line.IndexOf(',', index += 4);
table[className] = uint.Parse(line.Substring(index, index2 - index), System.Globalization.NumberStyles.HexNumber);
}
return table;
}
}

View file

@ -0,0 +1,20 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IsRoslynComponent>true</IsRoslynComponent>
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
<IncludeBuildOutput>true</IncludeBuildOutput>
<EnableNETAnalyzers>True</EnableNETAnalyzers>
<LangVersion>latest</LangVersion>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.4.0" PrivateAssets="all" />
<PackageReference Include="System.Collections.Immutable" Version="6.0.0" />
</ItemGroup>
<!--<ItemGroup>
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
</ItemGroup>-->
</Project>

View file

@ -570,12 +570,11 @@ namespace WTelegram
if (peek == Layer.RpcErrorCtor)
result = reader.ReadTLObject(Layer.RpcErrorCtor);
else if (peek == Layer.GZipedCtor)
using (var gzipReader = new BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress)))
result = gzipReader.ReadTLValue(rpc.type);
result = reader.ReadTLGzipped();
else
{
reader.BaseStream.Position -= 4;
result = reader.ReadTLValue(rpc.type);
result = reader.ReadTLVector(rpc.type);
}
}
if (rpc.type.IsEnum) result = Enum.ToObject(rpc.type, result);

View file

@ -4,12 +4,11 @@ using System.IO;
using System.Numerics;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
#if NET8_0_OR_GREATER
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
[JsonSerializable(typeof(WTelegram.Session))]
internal partial class WTelegramContext : JsonSerializerContext { }
#endif
@ -24,9 +23,9 @@ namespace WTelegram
/// <summary>For serializing indented Json with fields included</summary>
public static readonly JsonSerializerOptions JsonOptions = new() { IncludeFields = true, WriteIndented = true,
#if NET8_0_OR_GREATER
TypeInfoResolver = JsonSerializer.IsReflectionEnabledByDefault ? new DefaultJsonTypeInfoResolver() : WTelegramContext.Default,
TypeInfoResolver = JsonSerializer.IsReflectionEnabledByDefault ? null : WTelegramContext.Default,
#endif
IgnoreReadOnlyProperties = true, DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull };
IgnoreReadOnlyProperties = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull };
private static readonly ConsoleColor[] LogLevelToColor = [ ConsoleColor.DarkGray, ConsoleColor.DarkCyan,
ConsoleColor.Cyan, ConsoleColor.Yellow, ConsoleColor.Red, ConsoleColor.Magenta, ConsoleColor.DarkBlue ];

View file

@ -88,6 +88,9 @@ namespace TL
}
return null;
}
#if MTPG
public override void WriteTL(System.IO.BinaryWriter writer) => throw new NotImplementedException();
#endif
}
/// <summary>Accumulate users/chats found in this structure in your dictionaries, ignoring <see href="https://core.telegram.org/api/min">Min constructors</see> when the full object is already stored</summary>

View file

@ -11,7 +11,11 @@ using System.Text;
namespace TL
{
#if MTPG
public interface IObject { void WriteTL(BinaryWriter writer); }
#else
public interface IObject { }
#endif
public interface IMethod<ReturnType> : IObject { }
public interface IPeerResolver { IPeerInfo UserOrChat(Peer peer); }
@ -39,6 +43,7 @@ namespace TL
public sealed partial class ReactorError : IObject
{
public Exception Exception;
public void WriteTL(BinaryWriter writer) => throw new NotSupportedException();
}
public static class Serialization
@ -46,6 +51,9 @@ namespace TL
public static void WriteTLObject<T>(this BinaryWriter writer, T obj) where T : IObject
{
if (obj == null) { writer.WriteTLNull(typeof(T)); return; }
#if MTPG
obj.WriteTL(writer);
#else
var type = obj.GetType();
var tlDef = type.GetCustomAttribute<TLDefAttribute>();
var ctorNb = tlDef.CtorNb;
@ -63,14 +71,16 @@ namespace TL
if (field.Name == "flags") flags = (uint)value;
else if (field.Name == "flags2") flags |= (ulong)(uint)value << 32;
}
#endif
}
public static IObject ReadTLObject(this BinaryReader reader, uint ctorNb = 0)
{
#if MTPG
return reader.ReadTL(ctorNb);
#else
if (ctorNb == 0) ctorNb = reader.ReadUInt32();
if (ctorNb == Layer.GZipedCtor)
using (var gzipReader = new BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress)))
return ReadTLObject(gzipReader);
if (ctorNb == Layer.GZipedCtor) return reader.ReadTLGzipped();
if (!Layer.Table.TryGetValue(ctorNb, out var type))
throw new WTelegram.WTException($"Cannot find type for ctor #{ctorNb:x}");
if (type == null) return null; // nullable ctor (class meaning is associated with null)
@ -90,6 +100,7 @@ namespace TL
else if (field.Name == "flags2") flags |= (ulong)(uint)value << 32;
}
return (IObject)obj;
#endif
}
internal static void WriteTLValue(this BinaryWriter writer, object value, Type valueType)
@ -178,17 +189,6 @@ namespace TL
}
}
internal static void WriteTLVector(this BinaryWriter writer, Array array)
{
writer.Write(Layer.VectorCtor);
if (array == null) { writer.Write(0); return; }
int count = array.Length;
writer.Write(count);
var elementType = array.GetType().GetElementType();
for (int i = 0; i < count; i++)
writer.WriteTLValue(array.GetValue(i), elementType);
}
internal static void WriteTLMessages(this BinaryWriter writer, _Message[] messages)
{
writer.Write(messages.Length);
@ -209,6 +209,38 @@ namespace TL
}
}
internal static void WriteTLVector(this BinaryWriter writer, Array array)
{
writer.Write(Layer.VectorCtor);
if (array == null) { writer.Write(0); return; }
int count = array.Length;
writer.Write(count);
var elementType = array.GetType().GetElementType();
for (int i = 0; i < count; i++)
writer.WriteTLValue(array.GetValue(i), elementType);
}
internal static T[] ReadTLRawVector<T>(this BinaryReader reader, uint ctorNb)
{
int count = reader.ReadInt32();
var array = new T[count];
for (int i = 0; i < count; i++)
array[i] = (T)reader.ReadTLObject(ctorNb);
return array;
}
internal static T[] ReadTLVector<T>(this BinaryReader reader)
{
var elementType = typeof(T);
if (reader.ReadUInt32() is not Layer.VectorCtor and uint ctorNb)
throw new WTelegram.WTException($"Cannot deserialize {elementType.Name}[] with ctor #{ctorNb:x}");
int count = reader.ReadInt32();
var array = new T[count];
for (int i = 0; i < count; i++)
array[i] = (T)reader.ReadTLValue(elementType);
return array;
}
internal static Array ReadTLVector(this BinaryReader reader, Type type)
{
var elementType = type.GetElementType();
@ -240,14 +272,13 @@ namespace TL
internal static Dictionary<long, T> ReadTLDictionary<T>(this BinaryReader reader) where T : class, IPeerInfo
{
uint ctorNb = reader.ReadUInt32();
var elementType = typeof(T);
if (ctorNb != Layer.VectorCtor)
throw new WTelegram.WTException($"Cannot deserialize Vector<{elementType.Name}> with ctor #{ctorNb:x}");
throw new WTelegram.WTException($"Cannot deserialize Vector<{typeof(T).Name}> with ctor #{ctorNb:x}");
int count = reader.ReadInt32();
var dict = new Dictionary<long, T>(count);
for (int i = 0; i < count; i++)
{
var value = (T)reader.ReadTLValue(elementType);
var value = (T)reader.ReadTLObject();
dict[value.ID] = value is UserEmpty ? null : value;
}
return dict;
@ -317,6 +348,19 @@ namespace TL
writer.Write(0); // null arrays/strings are serialized as empty
}
internal static IObject ReadTLGzipped(this BinaryReader reader)
{
using var gzipReader = new BinaryReader(new GZipStream(new MemoryStream(reader.ReadTLBytes()), CompressionMode.Decompress));
return ReadTLObject(gzipReader);
}
internal static bool ReadTLBool(this BinaryReader reader) => reader.ReadUInt32() switch
{
0x997275b5 => true,
0xbc799737 => false,
var value => throw new WTelegram.WTException($"Invalid boolean value #{value:x}")
};
#if DEBUG
private static void ShouldntBeHere() => System.Diagnostics.Debugger.Break();
#else
@ -356,6 +400,9 @@ namespace TL
{
public Messages_AffectedMessages affected;
public override (long, int, int) GetMBox() => (0, affected.pts, affected.pts_count);
#if MTPG
public override void WriteTL(BinaryWriter writer) => throw new NotSupportedException();
#endif
}
// Below TL types are commented "parsed manually" from https://github.com/telegramdesktop/tdesktop/blob/dev/Telegram/Resources/tl/mtproto.tl

View file

@ -25,7 +25,7 @@
<PackageReadmeFile>README.md</PackageReadmeFile>
<PackageReleaseNotes>$(ReleaseNotes.Replace("|", "%0D%0A").Replace(" - ","%0D%0A- ").Replace(" ", "%0D%0A%0D%0A"))</PackageReleaseNotes>
<NoWarn>0419;1573;1591;NETSDK1138</NoWarn>
<DefineConstants>TRACE;OBFUSCATION</DefineConstants>
<DefineConstants>TRACE;OBFUSCATION;MTPG</DefineConstants>
<!--<IsAotCompatible Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net7.0'))">true</IsAotCompatible>-->
</PropertyGroup>
@ -42,6 +42,10 @@
<ItemGroup>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.1.1" PrivateAssets="All" />
</ItemGroup>-->
<ItemGroup Condition="$(DefineConstants.Contains('MTPG'))" >
<ProjectReference Include="..\generator\MTProtoGenerator.csproj" OutputItemType="Analyzer" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="IndexRange" Version="1.0.2" />