在两个流之间代理 WebSocket 消息

Proxying WebSocket messages between two streams

我有一个充当中间人的 HTTP 代理服务器。它基本上执行以下操作:

所以基本上有一个 NetworkStream,或者更常见的是 SslStream 在客户端浏览器和代理之间,另一个在代理和服务器之间。

出现了在客户端和服务器之间转发 WebSocket 流量的要求。

因此,现在当客户端浏览器请求连接升级到 websocket,并且远程服务器以 HTTP 代码 101 响应时,代理服务器会维护这些连接,以便将更多消息从客户端转发到服务器,反之亦然。

因此,在代理收到来自远程服务器的消息说它已准备好切换协议后,它需要进入一个循环,在该循环中轮询客户端和服务器流的数据,并将任何接收到的数据转发到对方.

问题

WebSocket允许双方随时发送消息。对于 ping/pong 之类的控制消息,这尤其是一个问题,其中任何一方都可以随时发送 ping,而另一方则应回复 pong 及时。现在考虑有两个没有 DataAvailable 属性 的 SslStream 实例,其中读取数据的唯一方法是调用 Read/ReadAsync,这可能在某些数据可用之前,不会 return。考虑以下伪代码:

public async Task GetMessage()
{
    // All these methods that we await read from the source stream
    byte[] firstByte = await GetFirstByte(); // 1-byte buffer
    byte[] messageLengthBytes = await GetMessageLengthBytes();
    uint messageLength = GetMessageLength(messageLengthBytes);
    bool isMessageMasked = DetermineIfMessageMasked(messageLengthBytes);
    byte[] maskBytes;
    if (isMessageMasked)
    {
        maskBytes = await GetMaskBytes();
    }

    byte[] messagePayload = await GetMessagePayload(messageLength);

    // This method writes to the destination stream
    await ComposeAndForwardMessageToOtherParty(firstByte, messageLengthBytes, maskBytes, messagePayload);
}

上面的伪代码从一个流中读取并写入另一个流。问题是上述过程需要同时对两个流进行 运行,因为我们不知道在任何给定时间点哪一方会向另一方发送消息。然而,当读操作处于活动状态时,不可能执行写操作。而且因为我们没有轮询传入数据所需的方法,所以读取操作必须是阻塞的。这意味着如果我们同时开始对两个流进行读取操作,我们就可以忘记写入它们。一个流最终会 return 一些数据,但我们无法将该数据发送到另一个流,因为它仍然忙于尝试读取。这可能需要一段时间,至少在拥有该流的一方发送 ping 请求之前。

感谢@MarcGravell 的评论,我们了解到网络流支持独立 read/write 操作,即 NetworkStream 充当两个 独立 管道- 一读一写 - 全双工。

因此,代理 WebSocket 消息可以像启动两个独立任务一样简单,一个从客户端流读取并写入服务器流,另一个从服务器流读取并写入客户端流。

如果它对任何搜索它的人有帮助,我是这样实现的:

public class WebSocketRequestHandler
{
    private const int MaxMessageLength = 0x7FFFFFFF;

    private const byte LengthBitMask = 0x7F;

    private const byte MaskBitMask = 0x80;

    private delegate Task WriteStreamAsyncDelegate(byte[] buffer, int offset, int count, CancellationToken cancellationToken);

    private delegate Task<byte[]> BufferStreamAsyncDelegate(int count, CancellationToken cancellationToken);

    public async Task HandleWebSocketMessagesAsync(CancellationToken cancellationToken = default(CancellationToken))
    {
        var clientListener = ListenForClientMessages(cancellationToken);
        var serverListener = ListenForServerMessages(cancellationToken);
        await Task.WhenAll(clientListener, serverListener);
    }

    private async Task ListenForClientMessages(CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            cancellationToken.ThrowIfCancellationRequested();
            await ListenForMessages(YOUR_CLIENT_STREAM_BUFFER_METHOD_DELEGATE, YOUR_SERVER_STREAM_WRITE_METHOD_DELEGATE, cancellationToken);
        }
    }

    private async Task ListenForServerMessages(CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            cancellationToken.ThrowIfCancellationRequested();
            await ListenForMessages(YOUR_SERVER_STREAM_BUFFER_METHOD_DELEGATE, YOUR_CLIENT_STREAM_WRITE_METHOD_DELEGATE, cancellationToken);
        }
    }

    private static async Task ListenForMessages(BufferStreamAsyncDelegate sourceStreamReader,
        WriteStreamAsyncDelegate destinationStreamWriter,
        CancellationToken cancellationToken)
    {
        var messageBuilder = new List<byte>();
        var firstByte = await sourceStreamReader(1, cancellationToken);
        messageBuilder.AddRange(firstByte);
        var lengthBytes = await GetLengthBytes(sourceStreamReader, cancellationToken);
        messageBuilder.AddRange(lengthBytes);
        var isMaskBitSet = (lengthBytes[0] & MaskBitMask) != 0;
        var length = GetMessageLength(lengthBytes);
        if (isMaskBitSet)
        {
            var maskBytes = await sourceStreamReader(4, cancellationToken);
            messageBuilder.AddRange(maskBytes);
        }

        var messagePayloadBytes = await sourceStreamReader(length, cancellationToken);
        messageBuilder.AddRange(messagePayloadBytes);
        await destinationStreamWriter(messageBuilder.ToArray(), 0, messageBuilder.Count, cancellationToken);
    }

    private static async Task<byte[]> GetLengthBytes(BufferStreamAsyncDelegate sourceStreamReader, CancellationToken cancellationToken)
    {
        var lengthBytes = new List<byte>();
        var firstLengthByte = await sourceStreamReader(1, cancellationToken);
        lengthBytes.AddRange(firstLengthByte);
        var lengthByteValue = firstLengthByte[0] & LengthBitMask;
        if (lengthByteValue <= 125)
        {
            return lengthBytes.ToArray();
        }

        switch (lengthByteValue)
        {
            case 126:
            {
                var secondLengthBytes = await sourceStreamReader(2, cancellationToken);
                lengthBytes.AddRange(secondLengthBytes);
                return lengthBytes.ToArray();
            }
            case 127:
            {
                var secondLengthBytes = await sourceStreamReader(8, cancellationToken);
                lengthBytes.AddRange(secondLengthBytes);
                return lengthBytes.ToArray();
            }
            default:
                throw new Exception($"Unexpected first length byte value: {lengthByteValue}");
        }
    }

    private static int GetMessageLength(byte[] lengthBytes)
    {
        byte[] subArray;
        switch (lengthBytes.Length)
        {
            case 1:
                return lengthBytes[0] & LengthBitMask;

            case 3:
                if (!BitConverter.IsLittleEndian)
                {
                    return BitConverter.ToUInt16(lengthBytes, 1);
                }

                subArray = lengthBytes.SubArray(1, 2);
                Array.Reverse(subArray);
                return BitConverter.ToUInt16(subArray, 0);

            case 9:
                subArray = lengthBytes.SubArray(1, 8);
                Array.Reverse(subArray);
                var retVal = BitConverter.ToUInt64(subArray, 0);
                if (retVal > MaxMessageLength)
                {
                    throw new Exception($"Unexpected payload length: {retVal}");
                }

                return (int) retVal;

            default:
                throw new Exception($"Impossibru!!1 The length of lengthBytes array was: '{lengthBytes.Length}'");
        }
    }
}

在执行初始握手后,只需调用 await handler.HandleWebSocketMessagesAsync(cancellationToken) 即可使用。

SubArray方法取自这里:(也来自@Marc哈哈)