如何从异步方法 return AggregateException

How to return AggregateException from async method

我得到了一个像增强的 Task.WhenAll 一样工作的异步方法。完成所有任务需要大量任务和 returns。

public async Task MyWhenAll(Task[] tasks) {
    ...
    await Something();
    ...

    // all tasks are completed
    if (someTasksFailed)
        throw ??
}

我的问题是,当一个或多个任务失败时,我如何获得 return 看起来像从 Task.WhenAll 编辑的 return 任务的方法?

如果我收集异常并抛出一个 AggregateException 它将被包装在另一个 AggregateException 中。

编辑:完整示例

async Task Main() {
    try {
        Task.WhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump();
    }

    try {
        MyWhenAll(Throw(1), Throw(2)).Wait();
    }
    catch (Exception ex) {
        ex.Dump();
    }
}

public async Task MyWhenAll(Task t1, Task t2) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    try {
        await Task.WhenAll(t1, t2);
    }
    catch {
        throw new AggregateException(new[] { t1.Exception, t2.Exception });
    }
}
public async Task Throw(int id) {
    await Task.Delay(TimeSpan.FromMilliseconds(100));
    throw new InvalidOperationException("Inner" + id);
}

对于 Task.WhenAll,例外是 AggregateException,有 2 个内部例外。

对于 MyWhenAll 例外是 AggregateException 有一个内部 AggregateException 有 2 个内部例外。

编辑:我为什么要这样做

我经常需要调用寻呼API:s,想限制同时连接数

实际的方法签名是

public static async Task<TResult[]> AsParallelAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel)
public static async Task<TResult[]> AsParallelUntilAsync<TResult>(this IEnumerable<Task<TResult>> source, int maxParallel, Func<Task<TResult>, bool> predicate)

表示我可以这样进行分页

var pagedRecords = await Enumerable.Range(1, int.MaxValue)
                                   .Select(x => GetRecordsAsync(pageSize: 1000, pageNumber: x)
                                   .AsParallelUntilAsync(maxParallel: 5, x => x.Result.Count < 1000);
var records = pagedRecords.SelectMany(x => x).ToList();

一切正常,聚合中的聚合只是一个小问题。

使用 TaskCompletionSource.

最外层的异常由 .Wait().Result 创建 - 这被记录为将存储在任务中的异常包装在 AggregateException 中(以保留其堆栈跟踪 - 这是在创建 ExceptionDispatchInfo 之前引入。

但是,Task 实际上可以包含很多异常。在这种情况下,.Wait().Result 将抛出包含多个 InnerExceptionsAggregateException。您可以通过 TaskCompletionSource.SetException(IEnumerable<Exception> exceptions).

访问此功能

所以你想创建你自己的AggregateException。在任务上设置多个例外,让 .Wait().Result 为您创建 AggregateException

所以:

var tcs = new TaskCompletionSource<object>();
tcs.SetException(new[] { t1.Exception, t2.Exception });
return tcs.Task;

当然,如果你再调用await MyWhenAll(..)MyWhenAll(..).GetAwaiter().GetResult(),那么它只会抛出第一个异常。这符合 Task.WhenAll.

的行为

这意味着您需要向上传递 tcs.Task 作为您方法的 return 值,这意味着您的方法不能是 async。您最终会做这样丑陋的事情(根据您的问题调整示例代码):

public static Task MyWhenAll(Task t1, Task t2)
{
    var tcs = new TaskCompletionSource<object>();
    var _ = Impl();
    return tcs.Task;

    async Task Impl()
    {
        await Task.Delay(10);
        try
        {
            await Task.WhenAll(t1, t2);
            tcs.SetResult(null);
        }
        catch
        {
            tcs.SetException(new[] { t1.Exception, t2.Exception });
        }
    }
}

不过,在这一点上,我会开始询问您为什么要这样做,以及为什么您不能使用 Task returned 来自 Task.WhenAll直接。

async 方法被设计为在返回的任务中每个最多设置一个异常,而不是多个。

这给您留下了两个选择,您可以不使用 async 方法开始,而是依赖其他方法来执行您的方法:

public Task MyWhenAll(Task t1, Task t2)
{
    return Task.Delay(TimeSpan.FromMilliseconds(100))
        .ContinueWith(_ => Task.WhenAll(t1, t2))
        .Unwrap();
}

如果你有一个更复杂的方法,如果不使用 await 将更难编写,那么你将需要解包嵌套的聚合异常,这很乏味,但不是太复杂,要做到:

    public static Task UnwrapAggregateException(this Task taskToUnwrap)
    {
        var tcs = new TaskCompletionSource<bool>();

        taskToUnwrap.ContinueWith(task =>
        {
            if (task.IsCanceled)
                tcs.SetCanceled();
            else if (task.IsFaulted)
            {
                if (task.Exception is AggregateException aggregateException)
                    tcs.SetException(Flatten(aggregateException));
                else
                    tcs.SetException(task.Exception);
            }
            else //successful
                tcs.SetResult(true);
        });

        IEnumerable<Exception> Flatten(AggregateException exception)
        {
            var stack = new Stack<AggregateException>();
            stack.Push(exception);
            while (stack.Any())
            {
                var next = stack.Pop();
                foreach (Exception inner in next.InnerExceptions)
                {
                    if (inner is AggregateException innerAggregate)
                        stack.Push(innerAggregate);
                    else
                        yield return inner;
                }
            }
        }

        return tcs.Task;
    }

我删除了我之前的答案,因为我找到了一个更简单的解决方案。此解决方案不涉及讨厌的 ContinueWith 方法或 TaskCompletionSource 类型。这个想法是 return 嵌套 Task<Task> 来自 local function, and Unwrap() 它来自外部容器函数。这是这个想法的基本概述:

public Task<T[]> GetAllAsync<T>()
{
    return LocalAsyncFunction().Unwrap();

    async Task<Task<T[]>> LocalAsyncFunction()
    {
        var tasks = new List<Task<T>>();
        // ...
        await SomethingAsync();
        // ...
        Task<T[]> whenAll = Task.WhenAll(tasks);
        return whenAll;
    }
}

GetAllAsync 方法不是 async. It delegates all the work to the LocalAsyncFunction, which is async, and then Unwraps the resulting nested task and returns it. The unwrapped task contains in its .Exception.InnerExceptions property all the exceptions of the tasks, because it is just a facade of the internal Task.WhenAll 任务。

让我们展示这个想法的更实际的实现。下面的 AsParallelUntilAsync 方法懒惰地枚举 source 序列并将它包含的项目投影到 Task<TResult>s,直到一个项目满足 predicate。它还限制了异步操作的并发性。困难在于枚举 IEnumerable<TSource> 也可能引发异常。在这种情况下,正确的行为是在传播枚举错误之前等待所有 运行 任务,并且 return 包含枚举错误和可能具有的所有任务错误的 AggregateException其间发生。这是如何完成的:

public static Task<TResult[]> AsParallelUntilAsync<TSource, TResult>(
    this IEnumerable<TSource> source, Func<TSource, Task<TResult>> action,
    Func<TSource, bool> predicate, int maxConcurrency)
{
    return Implementation().Unwrap();

    async Task<Task<TResult[]>> Implementation()
    {
        var tasks = new List<Task<TResult>>();

        async Task<TResult> EnumerateAsync()
        {
            var semaphore = new SemaphoreSlim(maxConcurrency, maxConcurrency);
            using var enumerator = source.GetEnumerator();
            while (true)
            {
                await semaphore.WaitAsync();
                if (!enumerator.MoveNext()) break;
                var item = enumerator.Current;
                if (predicate(item)) break;

                async Task<TResult> RunAndRelease(TSource item)
                {
                    try { return await action(item); }
                    finally { semaphore.Release(); }
                }

                tasks.Add(RunAndRelease(item));
            }
            return default; // A dummy value that will never be returned
        }

        Task<TResult> enumerateTask = EnumerateAsync();

        try
        {
            await enumerateTask; // Make sure that the enumeration succeeded
            Task<TResult[]> whenAll = Task.WhenAll(tasks);
            await whenAll; // Make sure that all the tasks succeeded
            return whenAll;
        }
        catch
        {
            // Return a faulted task that contains ALL the errors!
            return Task.WhenAll(tasks.Prepend(enumerateTask));
        }
    }
}