C# 9.0 source generator dependency injection registration not supporting async methods

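Question:

I'm trying to implement an automatic dependency injection registrator. My conventions are very strict, so it would be very useful.

I'm having trouble registering classes that contain async methods: the container seems to be processing these methods while it registers the class.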

About the project:

TL;DR

Some errors from the reproduction:

Code:

Source generator project:

namespace Test.Build.Tools
{
    using System.Collections.Generic;
    using System.Linq;
    using System.Text;
    using Microsoft.CodeAnalysis;
    using Microsoft.CodeAnalysis.Text;

    /// <summary>
    /// Auto register source generator.
    /// </summary>
    [Generator]
    public class AutoRegisterSourceGenerator : ISourceGenerator
    {
        /// <inheritdoc/>
        public void Initialize(GeneratorInitializationContext context)
        {
        }

        /// <inheritdoc/>
        public void Execute(GeneratorExecutionContext context)
        {
            StringBuilder stringBuilder = new("namespace Test.Extensions.DependencyInjection\n"
                                            + "{\n"
                                            + "    using System;\n"
                                            + "    using System.Threading.Tasks;\n"
                                            + "    using Microsoft.Extensions.DependencyInjection;\n");
            List<string> namespaces = new();

            string defaultPath = typeof(object).Assembly.Location.Replace("mscorlib", "{0}");

            List<MetadataReference> references = new()
            {
                { MetadataReference.CreateFromFile(string.Format(defaultPath, "System.Threading.Tasks")) }
            };

            var types = GetAllTypes(context.Compilation);
            var neededTypes = types.Where(t =>
            {
                string @namespace = t.ContainingNamespace.ToString();

                if (@namespace.Contains("Test")
                && !t.Interfaces.IsEmpty
                && t.TypeKind == TypeKind.Class)
                {
                    namespaces.Add(t.ContainingNamespace.ToString());
                    namespaces.Add(t.Interfaces[0].ContainingNamespace.ToString());
                    return true;
                }

                return false;
            }).ToList();

            namespaces.Distinct().OrderBy(n => n.ToString()).ToList().ForEach(n => stringBuilder.Append($"    using {n};\n"));

            stringBuilder.Append(
                "    /// <summary>\n" +
                "    /// Service registrator class.\n" +
                "    /// </summary>\n" +
                "    public static class ServicesRegistrator\n" +
                "    {\n" +
                "        /// <summary>\n" +
                "        /// Register dependency injection instances.\n" +
                "        /// </summary>\n" +
                "        /// <param name=\"services\">Startup services.</param>\n" +
                "        /// <returns>The given <see cref=\"IServiceCollection\"/> instance.</returns>\n" +
                "        public static IServiceCollection RegisterDomainModel(this IServiceCollection services)\n" +
                "        {\n");

            foreach (var type in neededTypes)
            {
                stringBuilder.Append($"            services.AddScoped<I{type.Name}, {type.Name}>();");
                stringBuilder.AppendLine();
            }

            stringBuilder.Append("            return services;\n" +
                "        }\n" +
                "    }\n" +
                "}\n");

            context.Compilation.AddReferences(references);

            context.AddSource("ServicesRegistrator", SourceText.From(stringBuilder.ToString(), Encoding.UTF8));
        }

        IEnumerable<INamedTypeSymbol> GetAllTypes(Compilation compilation) =>
            GetAllTypes(compilation.GlobalNamespace);

        IEnumerable<INamedTypeSymbol> GetAllTypes(INamespaceSymbol @namespace)
        {
            foreach (var type in @namespace.GetTypeMembers())
                foreach (var nestedType in GetNestedTypes(type))
                    yield return nestedType;

            foreach (var nestedNamespace in @namespace.GetNamespaceMembers())
                foreach (var type in GetAllTypes(nestedNamespace))
                    yield return type;
        }

        IEnumerable<INamedTypeSymbol> GetNestedTypes(INamedTypeSymbol type)
        {
            yield return type;
            foreach (var nestedType in type.GetTypeMembers()
                .SelectMany(nestedType => GetNestedTypes(nestedType)))
                yield return nestedType;
        }
    }
}

Model project:

namespace TestClasses
{
    using System.Threading.Tasks;

    public interface ITestClass
    {
        public Task TestMethod();
    }
}

namespace TestClasses.Model
{
    using System.Threading.Tasks;

    public class TestClass : ITestClass
    {
        public async Task TestMethod()
        {
            await Task.CompletedTask;
        }
    }
}

Executable:

using Executable;

Program.Rgister();

namespace Executable
{
    using Microsoft.Extensions.DependencyInjection;
    using Test.Extensions.DependencyInjection;
    using TestClasses;

    public class Program
    {
        public static void Rgister()
        {
            IServiceCollection services = new ServiceCollection();
            services.RegisterDomainModel();

            var x = services.BuildServiceProvider().GetRequiredService<ITestClass>();

            x.TestMethod();
        }
    }
}

Update:

Generated code:

namespace Test.Extensions.DependencyInjection
{
    using System;
    using System.Threading.Tasks;
    using Microsoft.Extensions.DependencyInjection;
    using TestClasses;
    using TestClasses.Model;
    /// <summary>
    /// Service registrator class.
    /// </summary>

    public static class ServicesRegistrator
    {
        /// <summary>
        /// Register dependency injection instances.
        /// </summary>
        /// <param name="services">Startup services.</param>
        /// <returns>The given <see cref="IServiceCollection"/> instance.</returns>
        public static IServiceCollection RegisterDomainModel(this IServiceCollection services)
        {
            services.AddScoped<ITestClass, TestClass>();
            return services;
        }
    }
}

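async/await is syntactic sugar that the compiler rewrites. After compilation, the body of an async method is moved into a compiler-generated state machine class. You can verify this with a tool such as ILSpy (in ILSpy, enable "View\Show all types and members").

Based on your model code, we can see that the compiled DLL contains this class: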

// TestClasses.Model.TestClass.<TestMethod>d__0
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

[CompilerGenerated]
private sealed class <TestMethod>d__0 : IAsyncStateMachine
{
    ...
}
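
If a decompiler is not at hand, the same nested type can be observed at run time with reflection. The following is a minimal sketch: `TestClass` mirrors the model class from the question, while `StateMachineDemo` is a hypothetical name introduced only for this illustration. The exact name and kind (class or struct) of the nested type depends on the compiler and build configuration, so no specific output is shown.

```csharp
using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

// Mirrors the model class from the question.
public class TestClass
{
    public async Task TestMethod()
    {
        await Task.CompletedTask;
    }
}

// Hypothetical demo type, not part of the original projects.
public static class StateMachineDemo
{
    public static void Main()
    {
        // The async state machine is emitted as a non-public nested type of TestClass.
        foreach (Type nested in typeof(TestClass).GetNestedTypes(BindingFlags.NonPublic))
        {
            bool generated = nested.IsDefined(typeof(CompilerGeneratedAttribute), inherit: false);
            Console.WriteLine($"{nested.FullName} compiler-generated: {generated}");
        }
    }
}
```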

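Many sugar keywords (such as yield) produce this kind of class after compilation. In your generator, you need to ignore these classes. To do that, check whether the class carries the CompilerGeneratedAttribute attribute.

It may also be possible to generate the injection code before these classes are generated.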
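
A minimal sketch of that check with the Roslyn symbol API, assuming the attribute is exposed through `GetAttributes()` for types loaded from referenced assemblies. `SymbolFilters` and `IsCompilerGenerated` are hypothetical names introduced for this illustration and are not part of the original generator.

```csharp
using System.Linq;
using Microsoft.CodeAnalysis;

// Hypothetical helper, not part of the original generator.
internal static class SymbolFilters
{
    private const string CompilerGeneratedName =
        "System.Runtime.CompilerServices.CompilerGeneratedAttribute";

    // Returns true when the type is marked [CompilerGenerated],
    // as async and iterator state machines are.
    public static bool IsCompilerGenerated(INamedTypeSymbol type) =>
        type.GetAttributes().Any(attribute =>
            attribute.AttributeClass?.ToDisplayString() == CompilerGeneratedName);
}
```

In the generator from the question, this could be used by adding `&& !SymbolFilters.IsCompilerGenerated(t)` to the condition inside the `Where` lambda, so that types such as `<TestMethod>d__0` never reach the `AddScoped` loop.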