How do I manage locks when implementing parallel task invocations with Task.WhenAll and max degree of parallelism?
I came up with the following code, which repeatedly calls a database paging function with a page size of 5 and, for each item in a page, executes a function in parallel with a maximum concurrency of 4. It appears to work so far, but I'm not sure whether I need locking around the parallelInvocationTasks.Remove(completedTask); line and the Task.WhenAll(parallelInvocationTasks.ToArray()); line.
So do I need locking here? And do you see any other improvements?
Here is the code.
Program.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace ConsoleApp1
{
    class Program
    {
        private static async Task Main(string[] args)
        {
            Console.WriteLine("Starting");

            Func<int, int, CancellationToken, Task<IList<string>>> getNextPageFunction = GetNextPageFromDatabase;
            await getNextPageFunction.ForEachParallel(4, 5, new CancellationToken(), async (item) =>
            {
                Console.WriteLine($"{item} started");
                //simulate processing
                await Task.Delay(1000);
                Console.WriteLine($"{item} ended");
            });

            Console.WriteLine("Done");
        }

        private static async Task<IList<string>> GetNextPageFromDatabase(
            int offset,
            int pageSize,
            CancellationToken cancellationToken)
        {
            //simulate i/o and database paging
            await Task.Delay(2000, cancellationToken);

            var pageData = new List<string>();

            //simulate just 4 pages
            if (offset >= pageSize * 3)
            {
                return pageData;
            }

            for (var i = 1; i <= pageSize; i++)
            {
                string nextItem = $"Item {i + offset}";
                pageData.Add(nextItem);
            }

            return pageData;
        }
    }
}
PagingExtensions.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace ConsoleApp1
{
    public static class PagingExtensions
    {
        public static async Task<int> ForEachParallel<TItem>(
            this Func<int, int, CancellationToken, Task<IList<TItem>>> getNextPageFunction,
            int concurrency,
            int pageSize,
            CancellationToken cancellationToken,
            Func<TItem, Task> forEachFunction)
        {
            var enumeratedCount = 0;
            if (getNextPageFunction == null || forEachFunction == null)
            {
                return enumeratedCount;
            }

            var offset = 0;
            using (var semaphore = new SemaphoreSlim(concurrency))
            {
                IList<Task> parallelInvocationTasks = new List<Task>();
                IList<TItem> items;
                do
                {
                    items = await getNextPageFunction(offset, pageSize, cancellationToken) ?? new List<TItem>();
                    foreach (TItem item in items)
                    {
                        await semaphore.WaitAsync(cancellationToken);

                        Task forEachFunctionTask = Task.Factory.StartNew(async () =>
                            {
                                try
                                {
                                    await forEachFunction(item);
                                }
                                finally
                                {
                                    // ReSharper disable once AccessToDisposedClosure
                                    // This is safe as long as Task.WhenAll is called before the using semaphore
                                    // enclosure ends
                                    semaphore.Release();
                                }
                            }, cancellationToken)
                            .Unwrap();

                        parallelInvocationTasks.Add(forEachFunctionTask);

#pragma warning disable 4014
                        forEachFunctionTask.ContinueWith((completedTask) =>
#pragma warning restore 4014
                        {
                            if (completedTask.Exception == null)
                            {
                                //Intention is to remove completed tasks from the list as they complete during
                                //enumeration so they can be GCed. This ensures the 'parallelInvocationTasks' list
                                //does not grow unbounded, holding on to completed tasks and consuming more memory
                                //with each added invocation task. The final Task.WhenAll call below then only has
                                //to await a minimal list of faulted and/or still-incomplete tasks (and throws if
                                //any of them faulted).
                                parallelInvocationTasks.Remove(completedTask);
                            }
                        }, cancellationToken);

                        enumeratedCount += 1;
                    }

                    offset += pageSize;
                }
                while (items.Count >= pageSize);

                await Task.WhenAll(parallelInvocationTasks.ToArray());
            }

            return enumeratedCount;
        }
    }
}
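To make the locking I am asking about concrete, here is a minimal sketch of what guarding the shared parallelInvocationTasks list might look like, since List&lt;T&gt; is not thread-safe and the ContinueWith callbacks can run on thread-pool threads while the enumeration loop is still adding tasks. The LockedTaskList type and its member names are purely illustrative and are not part of the code above.

using System.Collections.Generic;
using System.Threading.Tasks;

namespace ConsoleApp1
{
    // Hypothetical helper, for illustration only: every access to the shared task list
    // goes through a single lock object so no two threads touch the List<Task> at once.
    public sealed class LockedTaskList
    {
        private readonly object _listLock = new object();
        private readonly List<Task> _tasks = new List<Task>();

        public void Add(Task task)
        {
            lock (_listLock)
            {
                _tasks.Add(task);
            }
        }

        public void Remove(Task completedTask)
        {
            lock (_listLock)
            {
                _tasks.Remove(completedTask);
            }
        }

        public Task WhenAllRemaining()
        {
            // Take a snapshot under the lock so Task.WhenAll never enumerates the list
            // while a ContinueWith callback is removing a completed task from it.
            Task[] snapshot;
            lock (_listLock)
            {
                snapshot = _tasks.ToArray();
            }
            return Task.WhenAll(snapshot);
        }
    }
}

An alternative would be a concurrent collection such as ConcurrentDictionary keyed by the task itself, which avoids the explicit lock.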
OK, based on the comments above and some more research, I arrived at this answer. It gets the job done without having to write custom code to manage all the complexities of concurrency. It uses ActionBlock from TPL Dataflow.
PagingExtensions.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;

namespace ConsoleApp1
{
    public static class PagingExtensions
    {
        public delegate Task<IList<TItem>> GetNextPageDelegate<TItem>(
            int offset,
            int pageSize,
            CancellationToken cancellationToken);

        public static async Task<int> EnumerateParallel<TItem>(
            this GetNextPageDelegate<TItem> getNextPageFunction,
            int maxDegreeOfParallelism,
            int pageSize,
            CancellationToken cancellationToken,
            Func<TItem, Task> forEachFunction)
        {
            var enumeratedCount = 0;
            if (getNextPageFunction == null || forEachFunction == null)
            {
                return enumeratedCount;
            }

            var offset = 0;
            var forEachFunctionBlock = new ActionBlock<TItem>(forEachFunction, new ExecutionDataflowBlockOptions
            {
                BoundedCapacity = pageSize > maxDegreeOfParallelism ? pageSize : maxDegreeOfParallelism,
                EnsureOrdered = false,
                MaxDegreeOfParallelism = maxDegreeOfParallelism,
                CancellationToken = cancellationToken
            });

            IList<TItem> items;
            do
            {
                items = await getNextPageFunction(offset, pageSize, cancellationToken) ?? new List<TItem>();
                foreach (TItem item in items)
                {
                    await forEachFunctionBlock.SendAsync(item, cancellationToken);
                    enumeratedCount += 1;
                }

                offset += pageSize;
            }
            while (items.Count >= pageSize);

            forEachFunctionBlock.Complete();
            await forEachFunctionBlock.Completion;

            return enumeratedCount;
        }
    }
}
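For completeness, the call site changes slightly because the extension now hangs off the GetNextPageDelegate&lt;TItem&gt; type and is named EnumerateParallel. Below is a sketch of how Program.cs might be updated; the processedCount variable is just illustrative, and GetNextPageFromDatabase keeps the same simulated paging as in the question.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace ConsoleApp1
{
    class Program
    {
        private static async Task Main(string[] args)
        {
            Console.WriteLine("Starting");

            // The paging method now binds to the delegate type declared in PagingExtensions.
            PagingExtensions.GetNextPageDelegate<string> getNextPageFunction = GetNextPageFromDatabase;

            int processedCount = await getNextPageFunction.EnumerateParallel(4, 5, CancellationToken.None, async (item) =>
            {
                Console.WriteLine($"{item} started");
                await Task.Delay(1000); // simulate processing
                Console.WriteLine($"{item} ended");
            });

            Console.WriteLine($"Done, processed {processedCount} items");
        }

        private static async Task<IList<string>> GetNextPageFromDatabase(
            int offset,
            int pageSize,
            CancellationToken cancellationToken)
        {
            // Same simulated paging as in the question: three full pages, then an empty page.
            await Task.Delay(2000, cancellationToken);

            var pageData = new List<string>();
            if (offset >= pageSize * 3)
            {
                return pageData;
            }

            for (var i = 1; i <= pageSize; i++)
            {
                pageData.Add($"Item {i + offset}");
            }

            return pageData;
        }
    }
}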