async void EventHandler 中的异步调用导致死锁

async call in async void EventHandler leads to a deadlock

有没有办法在 OnConnect 中调用 SendAsync 而不导致死锁?我没有使用 .Wait() 或 .Result,但它仍然会导致死锁。

编辑:

实际问题是 SendAsync 被调用了两次(一次在 OnConnect 中,一次在 Main 中)。如果我在 Main 中的第二次调用之前加上 await Task.Delay(10000),它实际上就能正常工作。我该如何解决?如果没有这个延迟,它基本上会挂在 await tcs.Task.ConfigureAwait(false) 上:SendAsync 被调用了两次,而 async void OnConnect 某种程度上是“即发即弃”的,也就是说它不会等待第一个 SendAsync 完成,第二个调用就已经开始执行了。

// Program.cs
Client apiClient = new(key, secret);

// Kicks off the connection; the socket work continues in the background after this returns.
await apiClient.StartAsync().ConfigureAwait(false);

// Workaround only: waiting here keeps the requests below from overlapping with the
// auth request that the Connected handler sends.
await Task.Delay(3000);

await apiClient.SendAsync(request).ConfigureAwait(false);
await apiClient.SendAsync(request2).ConfigureAwait(false);

// Block so the process stays alive while background socket work runs.
Console.ReadLine();

// Client.cs
/// <summary>
/// JSON-RPC client over <see cref="WebSocketClient"/>. Each request is registered under its id,
/// sent over the socket, and completed when <see cref="OnMessageReceived"/> sees a response
/// carrying the same id.
/// </summary>
/// <remarks>
/// NOTE(review): the methods use <c>_client</c> and <c>_requests</c>, while the visible fields are
/// <c>_webSocket</c> and <c>_outstandingRequests</c>. Presumably the former are declared in the
/// elided section, or are the same fields under older names. Confirm.
/// </remarks>
public class Client
{
    private static long _nextId;
    private readonly WebSocketClient _webSocket;
    private readonly ConcurrentDictionary<long, TaskCompletionSource<string>> _outstandingRequests = new();

    ...

    public event EventHandler<ConnectEventArgs>? Connected;
    public event EventHandler<MessageReceivedEventArgs>? MessageReceived;

    /// <summary>
    /// Subscribes to the socket's connect and message events, then starts the socket.
    /// </summary>
    public ValueTask StartAsync()
    {
        _client.Connected += OnConnect;
        _client.MessageReceived += OnMessageReceived;

        return _webSocket.StartAsync();  // there is a long-running `Task.Run` inside it, which keeps the web socket connection and its pipelines open.
    }

    // Authenticates whenever the socket reports a connection. This handler is async void
    // (fire-and-forget): the code raising Connected does not wait for the auth round-trip,
    // and an exception thrown here is not returned to any caller.
    private async void OnConnect(object? sender, ConnectEventArgs e)
    {
        await AuthAsync(...); // the problematic line
    }

    // Completes the pending request whose id matches the incoming response.
    private void OnMessageReceived(object? sender, MessageReceivedEventArgs e)
    {
        ... deserialization stuff

        if (_requests.TryRemove(response.Id, out var tcs))
        {
            tcs.TrySetResult(message);
        }
    }

    /// <summary>
    /// Sends <paramref name="request"/> and waits for the response with the same id.
    /// </summary>
    /// <param name="request">The request to send; its <c>Id</c> correlates the response.</param>
    /// <returns>The response payload deserialized as <typeparamref name="TResponse"/>.</returns>
    /// <exception cref="InvalidOperationException">
    /// (Via the returned task) a request with the same id is already pending.
    /// </exception>
    public ValueTask<TResponse?> SendAsync<TResponse>(JsonRpcRequest request)
    {
        var tcs = new TaskCompletionSource<string>(TaskCreationOptions.RunContinuationsAsynchronously);
        if (!_requests.TryAdd(request.Id, tcs))
        {
            // An unregistered source is never completed by OnMessageReceived, so awaiting it
            // would hang forever. Fail instead.
            return ValueTask.FromException<TResponse?>(
                new InvalidOperationException($"A request with id {request.Id} is already pending."));
        }

        return SendRequestAndWaitForResponseAsync();

        async ValueTask<TResponse?> SendRequestAndWaitForResponseAsync()
        {
            try
            {
                var message = JsonSerializer.Serialize(request);
                await _client.SendAsync(message).ConfigureAwait(false);
            }
            catch
            {
                // No response can arrive for a request that was never sent;
                // drop the registration so it does not leak.
                _requests.TryRemove(request.Id, out _);
                throw;
            }

            var response = await tcs.Task.ConfigureAwait(false); // it hangs here (deadlock)

            return JsonSerializer.Deserialize<TResponse>(response);
        }
    }

    /// <summary>Sends the authentication request and returns its response.</summary>
    public ValueTask<JsonRpcResponse?> AuthAsync(JsonRpcRequest request)
    {
        return SendAsync<JsonRpcResponse>(request);
    }

    // Returns the next process-wide request id; atomic, so safe to call from several threads.
    private static long NextId()
    {
        return Interlocked.Increment(ref _nextId);
    }
}
/// <summary>
/// Web socket client built on a SignalR <c>WebSocketPipe</c>. It connects in the background with
/// a retry policy, raises <see cref="Connected"/> on each successful connect, and gates sends
/// until the connection is up.
/// </summary>
public sealed class WebSocketClient
{
    private readonly AsyncManualResetEvent _sendSemaphore = new(false); // Nito.AsyncEx

    // PipeWriter does not support concurrent writers; this serializes calls to Transport.Output.
    private readonly SemaphoreSlim _writeLock = new(1, 1);

    private readonly WebSocketPipe _webSocket; // SignalR Web Socket Pipe

    ...

    public event EventHandler<ConnectEventArgs>? Connected;
    public event EventHandler<DisconnectEventArgs>? Disconnected;
    public event EventHandler<MessageReceivedEventArgs>? MessageReceived;

    /// <summary>
    /// Starts connecting on a background task and returns an already-completed task.
    /// </summary>
    /// <remarks>
    /// Each connect attempt starts the pipe, raises <see cref="Connected"/>, opens the send gate,
    /// then runs the receive loop. <see cref="WebSocketException"/> is retried forever by
    /// <see cref="CreatePolicy"/>. Any other exception raises <see cref="Disconnected"/>.
    /// </remarks>
    public ValueTask StartAsync()
    {
        _ = Task.Run(async () =>
        {
            try
            {
                await CreatePolicy()
                    .ExecuteAsync(async () =>
                    {
                        await _webSocket.StartAsync(new Uri(_url), CancellationToken.None).ConfigureAwait(false);

                        Connected?.Invoke(this, new ConnectEventArgs());

                        // Releases every SendAsync waiting for the connection at once.
                        _sendSemaphore.Set();

                        await ReceiveLoopAsync().ConfigureAwait(false);
                    })
                    .ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                // Reached only for exceptions the retry policy does not handle
                // (it retries WebSocketException forever).
                Disconnected?.Invoke(this, new DisconnectEventArgs(ex));
            }
        });

        return ValueTask.CompletedTask;
    }

    /// <summary>
    /// Waits until connected, then writes <paramref name="message"/> as UTF-8 to the pipe.
    /// </summary>
    /// <param name="message">Text to send.</param>
    public async ValueTask SendAsync(string message)
    {
        await _sendSemaphore.WaitAsync().ConfigureAwait(false);

        var encoded = Encoding.UTF8.GetBytes(message);

        // Several callers can be released by _sendSemaphore.Set() together, e.g. the auth request
        // from a Connected handler and a request from Main. Without this lock they would write to
        // the same PipeWriter concurrently, which it does not support. That is a likely cause of
        // a request never receiving its response.
        await _writeLock.WaitAsync().ConfigureAwait(false);
        try
        {
            await _webSocket.Transport!.Output
                .WriteAsync(new ArraySegment<byte>(encoded, 0, encoded.Length), CancellationToken.None)
                .ConfigureAwait(false);
        }
        finally
        {
            _writeLock.Release();
        }
    }

    // Retries WebSocketException forever, waiting ReconnectInterval between attempts. Before
    // each retry it closes the send gate and raises Reconnecting.
    private IAsyncPolicy CreatePolicy()
    {
        var retryPolicy = Policy
            .Handle<WebSocketException>()
            .WaitAndRetryForeverAsync(_ => ReconnectInterval,
                (exception, retryCount, calculatedWaitDuration) =>
                {
                    _sendSemaphore.Reset();

                    Reconnecting?.Invoke(this, new ReconnectingEventArgs(exception, retryCount, calculatedWaitDuration));

                    return Task.CompletedTask;
                });

        return retryPolicy;
    }

    // Reads from the pipe's input until an exception ends the loop.
    private async Task ReceiveLoopAsync()
    {
        while (true)
        {
            var result = await _webSocket.Transport!.Input.ReadAsync(CancellationToken.None).ConfigureAwait(false);
            var buffer = result.Buffer;

            ...
        }
    }
}

如评论中所述,那些 WebSocket 包装器使用 System.IO.Pipelines,这是不正确的。System.IO.Pipelines 面向字节流,因此适用于(非 Web 的)普通套接字;而 WebSocket 是消息流,因此 System.Threading.Channels 更合适。

你可以试试类似下面的代码。我刚刚手打出来,还没有运行过:

/// <summary>
/// Exposes a <see cref="WebSocket"/> as two bounded channels. <see cref="Input"/> yields complete
/// received messages. Messages written to <see cref="Output"/> are sent in order by a single loop,
/// so the socket never sees concurrent sends.
/// </summary>
public sealed class ChannelWebSocket : IDisposable
{
    private readonly WebSocket _webSocket;
    private readonly Channel<Message> _input;
    private readonly Channel<Message> _output;

    /// <summary>Creates the wrapper and its bounded input and output channels.</summary>
    /// <param name="webSocket">The socket to wrap; disposed by <see cref="Dispose"/>.</param>
    /// <param name="options">Capacity, full-mode and dropped-item callback for each channel.</param>
    public ChannelWebSocket(WebSocket webSocket, Options options)
    {
        _webSocket = webSocket;
        _input = Channel.CreateBounded(new BoundedChannelOptions(options.InputCapacity)
        {
            FullMode = options.InputFullMode,
        }, options.InputMessageDropped);
        _output = Channel.CreateBounded(new BoundedChannelOptions(options.OutputCapacity)
        {
            FullMode = options.OutputFullMode,
        }, options.OutputMessageDropped);
    }

    /// <summary>
    /// Complete received messages. Nothing in this class disposes a message once it has been
    /// written here, so readers should dispose <see cref="Message.Payload"/>.
    /// </summary>
    public ChannelReader<Message> Input => _input.Reader;

    /// <summary>
    /// Messages to send. The output loop disposes each payload after handling it.
    /// Completing this writer makes the loop close the socket normally.
    /// </summary>
    public ChannelWriter<Message> Output => _output.Writer;

    public void Dispose() => _webSocket.Dispose();

    /// <summary>
    /// Starts the receive and send loops (fire-and-forget). If the first loop to finish faulted,
    /// the socket is closed with InternalServerError and both channels are completed with the fault.
    /// </summary>
    public async void Start()
    {
        var inputTask = InputLoopAsync(default);
        var outputTask = OutputLoopAsync(default);

        var completedTask = await Task.WhenAny(inputTask, outputTask);

        if (completedTask.Exception != null)
        {
            // Best-effort teardown: each step may fail if the socket/channel is already closed.
            try { await _webSocket.CloseAsync(WebSocketCloseStatus.InternalServerError, statusDescription: null, default); } catch { /* ignore */ }
            try { _input.Writer.Complete(completedTask.Exception); } catch { /* ignore */ }
            try { _output.Writer.Complete(completedTask.Exception); } catch { /* ignore */ }
        }
    }

    /// <summary>One complete web socket message and the pooled memory holding it.</summary>
    public sealed class Message
    {
        public WebSocketMessageType MessageType { get; set; }
        public OwnedMemorySequence<byte> Payload { get; set; } = null!;
    }

    // Receives frames into pooled buffers and writes each complete message to the input channel.
    // Completes the input channel when a Close message arrives.
    private async Task InputLoopAsync(CancellationToken cancellationToken)
    {
        while (true)
        {
            var payload = new OwnedMemorySequence<byte>();
            try
            {
                ValueWebSocketReceiveResult result;
                do
                {
                    // A fresh buffer per frame. Each appended slice must own distinct memory;
                    // otherwise later frames overwrite earlier ones, and disposing the payload
                    // would release the same buffer more than once.
                    var buffer = MemoryPool<byte>.Shared.Rent();
                    try
                    {
                        result = await _webSocket.ReceiveAsync(buffer.Memory, cancellationToken);
                    }
                    catch
                    {
                        buffer.Dispose();
                        throw;
                    }

                    if (result.MessageType == WebSocketMessageType.Close)
                    {
                        buffer.Dispose();
                        payload.Dispose();
                        _input.Writer.Complete();
                        return;
                    }

                    // The payload now owns the buffer (via the slice).
                    payload.Append(buffer.Slice(0, result.Count));
                } while (!result.EndOfMessage);

                await _input.Writer.WriteAsync(new Message
                {
                    MessageType = result.MessageType,
                    Payload = payload,
                }, cancellationToken);
            }
            catch
            {
                // The message never reached the channel, so its buffers are still ours to release.
                payload.Dispose();
                throw;
            }
        }
    }

    // Sends each queued message as one or more frames. Closes the socket normally once the
    // output channel is completed.
    private async Task OutputLoopAsync(CancellationToken cancellationToken)
    {
        await foreach (var message in _output.Reader.ReadAllAsync(cancellationToken))
        {
            // This loop owns every dequeued payload. Release it whether the message is
            // skipped, sent, or the send throws.
            using (message.Payload)
            {
                var sequence = message.Payload.ReadOnlySequence;
                if (sequence.IsEmpty)
                    continue;

                // Send every non-empty segment, holding one back so that only the final frame is
                // marked endOfMessage. Skipping empty segments guarantees progress and non-empty frames.
                ReadOnlyMemory<byte> pending = default;
                var hasPending = false;
                foreach (var segment in sequence)
                {
                    if (segment.IsEmpty)
                        continue;

                    if (hasPending)
                        await _webSocket.SendAsync(pending, message.MessageType, endOfMessage: false, cancellationToken);

                    pending = segment;
                    hasPending = true;
                }

                await _webSocket.SendAsync(pending, message.MessageType, endOfMessage: true, cancellationToken);
            }
        }

        await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, statusDescription: null, cancellationToken);
    }

    /// <summary>Channel sizing and overflow behavior for <see cref="ChannelWebSocket"/>.</summary>
    public sealed class Options
    {
        public int InputCapacity { get; set; } = 16;
        public BoundedChannelFullMode InputFullMode { get; set; } = BoundedChannelFullMode.Wait;
        public Action<Message>? InputMessageDropped { get; set; }

        public int OutputCapacity { get; set; } = 16;
        public BoundedChannelFullMode OutputFullMode { get; set; } = BoundedChannelFullMode.Wait;
        public Action<Message>? OutputMessageDropped { get; set; }
    }
}

它使用这种类型来构建内存序列:

/// <summary>
/// Chains memory buffers, in append order, into a linked list of segments that can be
/// exposed as a <see cref="ReadOnlySequence{T}"/>.
/// </summary>
public sealed class MemorySequence<T>
{
    private MemorySegment? _first;
    private MemorySegment? _last;

    /// <summary>Adds <paramref name="buffer"/> after the current last segment.</summary>
    /// <returns>This instance, so calls can be chained.</returns>
    public MemorySequence<T> Append(ReadOnlyMemory<T> buffer)
    {
        if (_last is not null)
        {
            _last = _last.Append(buffer);
            return this;
        }

        _last = new MemorySegment(buffer, runningIndex: 0);
        _first = _last;
        return this;
    }

    /// <summary>The whole chain, from the start of the first buffer to the end of the last.</summary>
    public ReadOnlySequence<T> ReadOnlySequence
    {
        get
        {
            var endIndex = _last is null ? 0 : _last.Memory.Length;
            return CreateReadOnlySequence(0, endIndex);
        }
    }

    /// <summary>
    /// Builds a sequence that starts at <paramref name="firstBufferStartIndex"/> within the first
    /// buffer and ends at <paramref name="lastBufferEndIndex"/> (exclusive) within the last.
    /// Returns an empty sequence when nothing has been appended.
    /// </summary>
    public ReadOnlySequence<T> CreateReadOnlySequence(int firstBufferStartIndex, int lastBufferEndIndex)
    {
        if (_last is null)
        {
            return new ReadOnlySequence<T>(Array.Empty<T>());
        }

        return new ReadOnlySequence<T>(_first!, firstBufferStartIndex, _last, lastBufferEndIndex);
    }

    // One link in the chain; RunningIndex is the total length of all earlier links.
    private sealed class MemorySegment : ReadOnlySequenceSegment<T>
    {
        public MemorySegment(ReadOnlyMemory<T> memory, long runningIndex) =>
            (Memory, RunningIndex) = (memory, runningIndex);

        // Creates the following link, wires it up as Next, and returns it.
        public MemorySegment Append(ReadOnlyMemory<T> nextMemory)
        {
            var successor = new MemorySegment(nextMemory, RunningIndex + Memory.Length);
            Next = successor;
            return successor;
        }
    }
}

此类型用于构建自有内存序列:

/// <summary>
/// A <see cref="MemorySequence{T}"/> that also holds the <see cref="IMemoryOwner{T}"/> of every
/// appended buffer. <see cref="Dispose"/> disposes the collection those owners were added to.
/// </summary>
public sealed class OwnedMemorySequence<T> : IDisposable
{
    // Owners of the appended buffers, released together in Dispose.
    private readonly CollectionDisposable _owners = new();
    private readonly MemorySequence<T> _segments = new();

    /// <summary>Takes ownership of <paramref name="memoryOwner"/> and appends its memory.</summary>
    /// <returns>This instance, so calls can be chained.</returns>
    public OwnedMemorySequence<T> Append(IMemoryOwner<T> memoryOwner)
    {
        _owners.Add(memoryOwner);
        _segments.Append(memoryOwner.Memory);
        return this;
    }

    /// <summary>All appended memory as one sequence.</summary>
    public ReadOnlySequence<T> ReadOnlySequence
    {
        get { return _segments.ReadOnlySequence; }
    }

    /// <summary>
    /// A sub-range starting at <paramref name="firstBufferStartIndex"/> in the first buffer and
    /// ending at <paramref name="lastBufferEndIndex"/> in the last buffer.
    /// </summary>
    public ReadOnlySequence<T> CreateReadOnlySequence(int firstBufferStartIndex, int lastBufferEndIndex)
    {
        return _segments.CreateReadOnlySequence(firstBufferStartIndex, lastBufferEndIndex);
    }

    /// <summary>Disposes every memory owner appended so far.</summary>
    public void Dispose()
    {
        _owners.Dispose();
    }
}

这依赖于一组针对“自有内存”(IMemoryOwner)的切片扩展方法,是我从 here(原帖中的链接)那里“偷”来的:
/// <summary>
/// Slicing for <see cref="IMemoryOwner{T}"/>. A slice exposes a sub-range of the owner's memory,
/// and disposing the slice disposes the underlying owner.
/// </summary>
public static class MemoryOwnerSliceExtensions
{
    /// <summary>
    /// Returns an owner over <paramref name="length"/> elements starting at <paramref name="start"/>.
    /// Returns <paramref name="owner"/> itself when the range covers all of its memory.
    /// </summary>
    public static IMemoryOwner<T> Slice<T>(this IMemoryOwner<T> owner, int start, int length) =>
        start == 0 && length == owner.Memory.Length
            ? owner
            : new SliceOwner<T>(owner, start, length);

    /// <summary>
    /// Returns an owner over the memory from <paramref name="start"/> to the end.
    /// Returns <paramref name="owner"/> itself when <paramref name="start"/> is 0.
    /// </summary>
    public static IMemoryOwner<T> Slice<T>(this IMemoryOwner<T> owner, int start) =>
        start == 0 ? owner : new SliceOwner<T>(owner, start);

    // A view onto part of another owner's memory; disposal is forwarded to that owner.
    private sealed class SliceOwner<T> : IMemoryOwner<T>
    {
        private readonly IMemoryOwner<T> _parent;

        public SliceOwner(IMemoryOwner<T> owner, int start, int length)
        {
            _parent = owner;
            Memory = owner.Memory.Slice(start, length);
        }

        public SliceOwner(IMemoryOwner<T> owner, int start)
        {
            _parent = owner;
            Memory = owner.Memory.Slice(start);
        }

        public Memory<T> Memory { get; }

        public void Dispose() => _parent.Dispose();
    }
}

此代码完全未经测试;使用风险自负。