Winsock2 在环回网络上丢弃消息

Winsock2 drops messages on the loopback network

我编写了一个服务器套接字，希望以多线程方式支持多个连接，但遇到了一个问题：它会莫名其妙地丢弃来自客户端的消息。

每个套接字都由各自的线程处理，所以我猜问题应该不在这里（不过也可能恰恰在这里）。

这是代码

#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <WinSock2.h>
#include <WS2tcpip.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <string_view>
#include <thread>
#include <vector>

// Taken from: 
static std::string wsa_error_to_string(int wsa_error)
{
    char msgbuf [256];   // for a message up to 255 bytes.
    msgbuf [0] = '[=10=]';    // Microsoft doesn't guarantee this on man page.

    FormatMessage(
        FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, // flags
        nullptr,                                                    // lpsource
        wsa_error,                                                  // message id
        MAKELANGID (LANG_NEUTRAL, SUBLANG_DEFAULT),                 // languageid
        msgbuf,                                                     // output buffer
        sizeof (msgbuf),                                            // size of msgbuf, bytes
        nullptr
    );

    if (! *msgbuf)
        sprintf (msgbuf, "%d", wsa_error);  // provide error # if no string available
    return msgbuf;
}

// Prints MSG to stderr and terminates the process.
// std::abort() is used instead of assert(0): assert expands to nothing when
// NDEBUG is defined, which would let release builds run on past a fatal error.
#define PRINT_ERROR_AND_TERMINATE(MSG) do { std::cerr << (MSG) << std::endl; std::abort(); } while (0)

// Scope guard for the Winsock library: calls WSAStartup (version 2.2) on
// construction and the matching WSACleanup on destruction.
struct wsa_lifetime
{
    wsa_lifetime()
    {
        // WSAStartup reports failure through its return value.
        const int result = ::WSAStartup(MAKEWORD(2, 2), &wsa_data);
        is_initialized = (result == 0);
        if (!is_initialized)
            std::cerr << "WSAStartup failed: " << result << std::endl;
        assert(is_initialized);
    }

    ~wsa_lifetime()
    {
        // Only a successful WSAStartup may be balanced by WSACleanup.
        if (is_initialized)
            ::WSACleanup();
    }

    // A copy would call WSACleanup a second time for a single WSAStartup.
    wsa_lifetime(const wsa_lifetime &) = delete;
    wsa_lifetime &operator=(const wsa_lifetime &) = delete;

    WSAData wsa_data {};
    bool is_initialized {false};   // true only if WSAStartup returned 0
};

// File-scope instance so Winsock is initialised during static initialisation of
// this translation unit. Named differently from the type so that the type name
// is not hidden for the rest of the file.
static wsa_lifetime wsa_lifetime_instance;

// Creates an IPv4 TCP socket.
// Terminates the process (after printing the Winsock error) if creation fails;
// a runtime check is used because an assert would vanish in release builds.
static SOCKET socket_create()
{
    const SOCKET new_socket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (new_socket == INVALID_SOCKET)
        PRINT_ERROR_AND_TERMINATE(wsa_error_to_string(WSAGetLastError()));
    return new_socket;
}

// Closes `socket` if it is a valid handle.
// The handle is taken by value, so the caller's variable is NOT reset to
// INVALID_SOCKET; callers must not use it after this call.
// Close failures are reported on stderr but are not fatal (best-effort cleanup).
static void socket_destroy(SOCKET socket)
{
    if (socket == INVALID_SOCKET)
        return;

    if (::closesocket(socket) == SOCKET_ERROR)
        std::cerr << "closesocket failed: " << wsa_error_to_string(WSAGetLastError()) << std::endl;
}

// Binds `socket` to the dotted-decimal IPv4 `address` and `port` (host byte order).
// Terminates the process if the address cannot be parsed or bind() fails.
static void socket_bind(SOCKET socket, const char *address, uint16_t port)
{
    sockaddr_in addr {};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);

    // inet_pton returns 1 on success, 0 for a malformed string and -1 on error.
    // Ignoring it would leave sin_addr zeroed and silently bind to 0.0.0.0.
    if (::inet_pton(AF_INET, address, &addr.sin_addr) != 1)
        PRINT_ERROR_AND_TERMINATE(std::string("invalid IPv4 address: ") + address);

    const int bind_result = ::bind(socket, reinterpret_cast<const SOCKADDR *>(&addr),
                                   static_cast<int>(sizeof(addr)));
    if (bind_result == SOCKET_ERROR)
        PRINT_ERROR_AND_TERMINATE(wsa_error_to_string(WSAGetLastError()));
}

// Connects `socket` to the dotted-decimal IPv4 `address` and `port` (host byte order).
// Terminates the process if the address cannot be parsed or connect() fails.
static void socket_connect(SOCKET socket, const char *address, uint16_t port)
{
    sockaddr_in addr {};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);

    // inet_pton returns 1 on success, 0 for a malformed string and -1 on error.
    // Ignoring it would leave sin_addr zeroed and connect to 0.0.0.0.
    if (::inet_pton(AF_INET, address, &addr.sin_addr) != 1)
        PRINT_ERROR_AND_TERMINATE(std::string("invalid IPv4 address: ") + address);

    const int connect_result = ::connect(socket, reinterpret_cast<const SOCKADDR *>(&addr),
                                         static_cast<int>(sizeof(addr)));
    if (connect_result == SOCKET_ERROR)
        PRINT_ERROR_AND_TERMINATE(wsa_error_to_string(WSAGetLastError()));
}

// Puts `socket` into the listening state with a SOMAXCONN backlog.
// Terminates the process with a readable Winsock message on failure.
static void socket_listen(SOCKET socket)
{
    if (::listen(socket, SOMAXCONN) == SOCKET_ERROR)
        PRINT_ERROR_AND_TERMINATE(wsa_error_to_string(WSAGetLastError()));
}

// Blocks until a connection arrives on the listening `socket` and returns the
// connected socket (peer address is not requested).
// Terminates the process with a readable Winsock message on failure.
static SOCKET socket_accept(SOCKET socket)
{
    const SOCKET peer = ::accept(socket, nullptr, nullptr);
    if (peer == INVALID_SOCKET)
        PRINT_ERROR_AND_TERMINATE(wsa_error_to_string(WSAGetLastError()));
    return peer;
}

// Receives up to `buffer_size` bytes from `socket` into `buffer`.
// Returns the number of bytes received. 0 means the peer is gone: either it
// closed the connection gracefully, or it was reset or aborted.
// Any other Winsock error terminates the process.
static size_t socket_recv(SOCKET socket, char *buffer, size_t buffer_size, int flags = 0)
{
    // recv() takes an int length; clamp instead of letting the cast wrap negative.
    const int request = static_cast<int>(std::min<size_t>(buffer_size, static_cast<size_t>(INT_MAX)));

    const int bytes_received = ::recv(socket, buffer, request, flags);
    if (bytes_received == SOCKET_ERROR)
    {
        const int err = WSAGetLastError();
        if (err == WSAECONNRESET || err == WSAECONNABORTED)
            return 0; // Disconnected client
        PRINT_ERROR_AND_TERMINATE(wsa_error_to_string(err));
        return 0; // never report SOCKET_ERROR (-1) as a huge size_t if the macro returns
    }
    return static_cast<size_t>(bytes_received);
}

// Sends all `data_size` bytes of `data` on `socket`, retrying after partial sends.
// Returns the number of bytes actually sent. This equals data_size on success;
// it is smaller (possibly 0) if the peer disconnected (reset/abort) first.
// Any other Winsock error terminates the process.
static size_t socket_send(SOCKET socket, const char *data, size_t data_size, int flags = 0)
{
    size_t total_sent = 0;
    while (total_sent < data_size)
    {
        // send() takes an int length; send in INT_MAX-sized chunks at most.
        const size_t remaining = data_size - total_sent;
        const int chunk = static_cast<int>(std::min<size_t>(remaining, static_cast<size_t>(INT_MAX)));

        const int bytes_sent = ::send(socket, data + total_sent, chunk, flags);
        if (bytes_sent == SOCKET_ERROR)
        {
            const int err = WSAGetLastError();
            if (err == WSAECONNRESET || err == WSAECONNABORTED)
                break; // Disconnected client
            PRINT_ERROR_AND_TERMINATE(wsa_error_to_string(err));
            break;
        }
        if (bytes_sent == 0)
            break; // defensive: avoid spinning forever if nothing can be sent

        total_sent += static_cast<size_t>(bytes_sent);
    }
    return total_sent;
}

static std::mutex output_mutex;

int main()
{
    const char *server_address = "127.0.0.1";
    uint16_t server_port = 23456;
    bool server_terminate = false;

    std::thread server_thread([server_address, server_port, &server_terminate](){
        SOCKET server = socket_create();
        socket_bind(server, server_address, server_port);
        socket_listen(server);

        std::vector<SOCKET> clients;
        std::vector<std::thread> client_threads;

        while (!server_terminate)
        {
            SOCKET incoming_client = socket_accept(server);
            if (server_terminate)
                break;

            clients.push_back(incoming_client);
            size_t client_id = clients.size();

            std::thread incoming_client_thread([&incoming_client, client_id](){
                const size_t data_size = 1024;
                char data[data_size];

                while (true)
                {
                    size_t bytes_received = socket_recv(incoming_client, data, data_size, 0);
                    if (bytes_received == 0)
                        break;

                    std::string_view client_message(data, bytes_received);
                    {
                        std::unique_lock lock(output_mutex);
                        std::cout << "Client (" << client_id << "): " << client_message << std::endl;
                    }
                }
            });
            client_threads.push_back(std::move(incoming_client_thread));
        }

        for (std::thread &client_thread: client_threads)
            if (client_thread.joinable())
                client_thread.join();
    });

    std::vector<SOCKET> clients;
    std::vector<std::thread> client_threads;

    for (int i = 0; i < 4; i++)
    {
        SOCKET client = socket_create();
        clients.push_back(client);
    }

    for (SOCKET client : clients)
    {
        std::thread client_thread([server_address, server_port, client](){
            socket_connect(client, server_address, server_port);

            for (int i = 0; i < 10; i++)
            {
                std::string data_str = (std::stringstream() << "hello " << i).str();
                socket_send(client, data_str.c_str(), data_str.size());

                using namespace std::chrono_literals;
                std::this_thread::sleep_for(100ms + 1ms * (rand() % 100));
            }
        });
        client_threads.push_back(std::move(client_thread));
    }

    for (std::thread &client_thread : client_threads)
        if (client_thread.joinable())
            client_thread.join();

    for (SOCKET client: clients)
        socket_destroy(client);
    clients.clear();

    server_terminate = true;

    SOCKET dummy_socket = socket_create();
    socket_connect(dummy_socket, server_address, server_port); // just to unblock server's socket_accept() blocking call
    socket_destroy(dummy_socket);

    if (server_thread.joinable())
        server_thread.join();

    return 0;
}

可能的输出:

Client (2): hello 0
Client (2): hello 0
Client (3): hello 1
Client (2): hello 2
Client (1): hello 3
Client (4): hello 4
Client (3): hello 5
Client (2): hello 6
Client (1): hello 7
Client (4): hello 8
Client (3): hello 9

我预计每个客户端发送 10 条消息，总共 40 条，但如你所见，有些消息丢失了。由于所有通信都在本机环回网络上完成，我认为即使使用 UDP 传输也不应该丢失消息。
Wireshark 捕获到了全部消息。

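在构造 incoming_client_thread 所用的 lambda 时，你是按引用而不是按值捕获 incoming_client 的。

incoming_client 是 while 循环体内的局部变量，每次迭代结束时都会被销毁。下一次迭代中，它（通常复用同一块栈空间）又会被新的 socket_accept 返回值重新初始化。因此一旦下一次 socket_accept 成功，先前启动的线程就可能转而对新的套接字调用 socket_recv，而原来的套接字再也没有人读取，那些消息看起来就像“丢失”了。严格来说，这是悬空引用导致的未定义行为。改为按值捕获即可：[incoming_client, client_id]。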