如何实现高效的 WhenEach 流式传输任务结果的 IAsyncEnumerable?

How to implement an efficient WhenEach that streams an IAsyncEnumerable of task results?

我正在尝试使用 C# 8, and one method that seems particularly useful is a version of Task.WhenAll that returns an IAsyncEnumerable 提供的新工具更新我的工具集。此方法应在任务结果可用时立即流式传输,因此将其命名为 WhenAll 没有多大意义。 WhenEach听起来更合适。该方法的签名是:

public static IAsyncEnumerable<TResult> WhenEach<TResult>(Task<TResult>[] tasks);

这个方法可以这样使用:

var tasks = new Task<int>[]
{
    ProcessAsync(1, 300),
    ProcessAsync(2, 500),
    ProcessAsync(3, 400),
    ProcessAsync(4, 200),
    ProcessAsync(5, 100),
};

await foreach (int result in WhenEach(tasks))
{
    Console.WriteLine($"Processed: {result}");
}

static async Task<int> ProcessAsync(int result, int delay)
{
    await Task.Delay(delay);
    return result;
}

预期输出:

Processed: 5
Processed: 4
Processed: 1
Processed: 3
Processed: 2

我设法在循环中使用方法 Task.WhenAny 编写了一个基本实现,但是这种方法有一个问题:

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks)
{
    var hashSet = new HashSet<Task<TResult>>(tasks);
    while (hashSet.Count > 0)
    {
        var task = await Task.WhenAny(hashSet).ConfigureAwait(false);
        yield return await task.ConfigureAwait(false);
        hashSet.Remove(task);
    }
}

问题出在性能上。 Task.WhenAnyimplementation 创建了所提供任务列表的防御副本,因此在循环中重复调用它会导致 O(n²) 计算复杂度。我天真的实现努力处理 10,000 个任务。在我的机器上,开销将近 10 秒。我希望该方法的性能几乎与内置 Task.WhenAll 一样,可以轻松处理数十万个任务。我怎样才能改进 WhenEach 方法以使其正常运行?

通过使用 this 文章中的代码,您可以实现以下内容:

public static Task<Task<T>>[] Interleaved<T>(IEnumerable<Task<T>> tasks)
{
   var inputTasks = tasks.ToList();

   var buckets = new TaskCompletionSource<Task<T>>[inputTasks.Count];
   var results = new Task<Task<T>>[buckets.Length];
   for (int i = 0; i < buckets.Length; i++)
   {
       buckets[i] = new TaskCompletionSource<Task<T>>();
       results[i] = buckets[i].Task;
   }

   int nextTaskIndex = -1;
   Action<Task<T>> continuation = completed =>
   {
       var bucket = buckets[Interlocked.Increment(ref nextTaskIndex)];
       bucket.TrySetResult(completed);
   };

   foreach (var inputTask in inputTasks)
       inputTask.ContinueWith(continuation, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);

   return results;
}

然后将您的 WhenEach 更改为调用 Interleaved 代码

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(Task<TResult>[] tasks)
{
    foreach (var bucket in Interleaved(tasks))
    {
        var t = await bucket;
        yield return await t;
    }
}

然后你可以照常打电话给你的WhenEach

await foreach (int result in WhenEach(tasks))
{
    Console.WriteLine($"Processed: {result}");
}

我对 10k 任务进行了一些基本的基准测试,在速度方面的表现提高了 5 倍。

您可以将 Channel 用作异步队列。每个任务完成后都可以写入通道。频道中的项目将 return 通过 ChannelReader.ReadAllAsync 作为 IAsyncEnumerable 编辑。

IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<Task<T>> inputTasks)
{
    var channel=Channel.CreateUnbounded<T>();
    var writer=channel.Writer;
    var continuations=inputTasks.Select(t=>t.ContinueWith(x=>
                                           writer.TryWrite(x.Result)));
    _ = Task.WhenAll(continuations)
            .ContinueWith(t=>writer.Complete(t.Exception));

    return channel.Reader.ReadAllAsync();
}

当所有任务完成时调用writer.Complete()关闭通道。

为了对此进行测试,此代码会生成延迟减少的任务。这应该 return 索引的顺序相反:

var tasks=Enumerable.Range(1,4)
                    .Select(async i=>
                    { 
                      await Task.Delay(300*(5-i));
                      return i;
                    });

await foreach(var i in Interleave(tasks))
{
     Console.WriteLine(i);

}

产生:

4
3
2
1

只是为了好玩,使用 System.Reactive and System.Interactive.Async:

public static async IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks)
    => Observable.Merge(tasks.Select(t => t.ToObservable())).ToAsyncEnumerable()

我真的很喜欢 ,但仍然希望像 JohanP 的解决方案那样在发生异常时引发异常。

为了实现这一点,我们可以稍微修改一下,以尝试在任务失败时关闭延续中的通道:

public IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<Task<T>> inputTasks)
{
    if (inputTasks == null)
    {
        throw new ArgumentNullException(nameof(inputTasks), "Task list must not be null.");
    }

    var channel = Channel.CreateUnbounded<T>();
    var channelWriter = channel.Writer;
    var inputTaskContinuations = inputTasks.Select(inputTask => inputTask.ContinueWith(completedInputTask =>
    {
        // Check whether the task succeeded or not
        if (completedInputTask.Status == TaskStatus.RanToCompletion)
        {
            // Write the result to the channel on successful completion
            channelWriter.TryWrite(completedInputTask.Result);
        }
        else
        {
            // Complete the channel on failure to immediately communicate the failure to the caller and prevent additional results from being returned
            var taskException = completedInputTask.Exception?.InnerException ?? completedInputTask?.Exception;
            channelWriter.TryComplete(taskException);
        }
    }));

    // Ensure the writer is closed after the tasks are all complete, and propagate any exceptions from the continuations
    _ = Task.WhenAll(inputTaskContinuations).ContinueWith(completedInputTaskContinuationsTask => channelWriter.TryComplete(completedInputTaskContinuationsTask.Exception));

    // Return the async enumerator of the channel so results are yielded to the caller as they're available
    return channel.Reader.ReadAllAsync();
}

这样做的明显缺点是遇到的第一个错误将结束枚举并阻止返回任何其他可能成功的结果。这是我的用例可以接受的权衡,但可能不适用于其他人。

我正在为这个问题添加一个答案,因为有几个问题需要解决。

  1. 建议创建异步可枚举序列的方法应该有一个 CancellationToken 参数。这会在 await foreach 循环中启用 WithCancellation 配置。
  2. 建议当异步操作将延续附加到任务时,应在操作完成时清除这些延续。因此,如果 WhenEach 方法的调用者决定提前退出 await foreach 循环(使用 breakreturn 等),或者如果循环由于例外,我们不想留下一堆死的延续,附加到任务上。如果在循环中重复调用 WhenEach(例如,作为 Retry 功能的一部分),这一点尤其重要。

下面的实现解决了这两个问题。它基于 Channel<Task<TResult>>. Now the channels 已经成为 .NET 平台不可或缺的一部分,因此没有理由回避它们以支持更复杂的基于 TaskCompletionSource 的解决方案。

public async static IAsyncEnumerable<TResult> WhenEach<TResult>(
    Task<TResult>[] tasks,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    if (tasks == null) throw new ArgumentNullException(nameof(tasks));
    var channel = Channel.CreateUnbounded<Task<TResult>>();
    using var completionCts = new CancellationTokenSource();
    var continuations = new List<Task>(tasks.Length);
    try
    {
        int pendingCount = tasks.Length;
        foreach (var task in tasks)
        {
            if (task == null) throw new ArgumentException(
                $"The tasks argument included a null value.", nameof(tasks));
            continuations.Add(task.ContinueWith(t =>
            {
                bool accepted = channel.Writer.TryWrite(t);
                Debug.Assert(accepted);
                if (Interlocked.Decrement(ref pendingCount) == 0)
                    channel.Writer.Complete();
            }, completionCts.Token, TaskContinuationOptions.ExecuteSynchronously |
                TaskContinuationOptions.DenyChildAttach, TaskScheduler.Default));
        }

        await foreach (var task in channel.Reader.ReadAllAsync(cancellationToken)
            .ConfigureAwait(false))
        {
            yield return await task.ConfigureAwait(false);
            cancellationToken.ThrowIfCancellationRequested();
        }
    }
    finally
    {
        completionCts.Cancel();
        try { await Task.WhenAll(continuations).ConfigureAwait(false); }
        catch (OperationCanceledException) { } // Ignore
    }
}

finally 块负责取消附加的延续,并在退出前等待它们完成。

await foreach 循环中的 ThrowIfCancellationRequested 可能看起来多余,但实际上是必需的,因为 ReadAllAsync 方法的设计行为,解释为 here.


注意: finally 块中的 OperationCanceledException 被低效的 try/catch 块抑制。捕获异常 is expensive. A more efficient implementation would suppress the error by awaiting the continuations with a specialized SuppressException awaiter, like the one featured in this 答案,并特殊处理 IsCanceled 案例。就此答案而言,解决这种低效率问题可能有点矫枉过正。 WhenEach 方法不太可能在紧密循环中使用。