C++ 非阻塞 send/recv 只在一侧工作(客户端或服务器)

C++ non-blocking send/recv work only on one side(client or server)

我正在使用非阻塞套接字在 C++ 中实现一个服务器。

因为我想在客户端和服务器之间发送消息,所以我围绕 send/recv 系统调用编写了 2 个包装器。主要是,我想在每条消息前加上 4Bytes(消息长度),以便接收方知道执行 recv 需要多长时间。

此外,我有一个 client/server 程序,每个程序都启动一个套接字并在本地主机上侦听。 然后客户端发送一条随机消息,服务器接收。

但是,当我尝试从服务器发送到客户端时,两个程序都停止了。

我已经多次测试包装器并且它们 read/deliver 数据,但是每当我尝试在先前发送的连接上接收时,问题就来了。

我知道 secure_recv 中存在内存泄漏,但我需要它来通过一些自定义测试,这些测试写得不是很好。

问题出在select,其中returns是一个正数,但后来我再也没有进入if (FD_ISSET(fd, &readset))声明。

我做错了什么,我们该如何解决?非常感谢!

编辑 我的问题是套接字在 select 函数处阻塞(忙于工作)。我更新了代码,使 secure_* 函数中没有 select。这是一种更好的方法,首先通过 select 检查套接字是否可用于 client/server 线程级别的 send/recv,然后调用 secure_* 函数。问题暂时得到解答。

client.cpp

// Client side C/C++ program to demonstrate Socket programming
#include <stdio.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string.h>
#include "util.h"
#define PORT 8080

int main(int argc, char const *argv[])
{
    int sock = 0, valread;
    struct sockaddr_in serv_addr;
    if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0)
    {
        printf("\n Socket creation error \n");
        return -1;
    }

    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port = htons(PORT);

    // Convert IPv4 and IPv6 addresses from text to binary form
    if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) <= 0)
    {
        printf("\nInvalid address/ Address not supported \n");
        return -1;
    }

    if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0)
    {
        printf("\nConnection Failed \n");
        return -1;
    }

    int numbytes;
    size_t size = 0;
    std::unique_ptr<char[]> buf = get_rand_data(size);
    if ((numbytes = secure_send(sock, buf.get(), size, 0)) == -1)
    {
        std::cout << std::strerror(errno) << "\n";
        exit(1);
    }
    std::cout << "Client sent : " << numbytes << "\n";

    int64_t bytecount = -1;
    while (1)
    {
        std::unique_ptr<char[]> buffer;
        if ((bytecount = secure_recv(sock, buffer, 0)) <= 0)
        {
            if (bytecount == 0)
            {
                break;
            }
        }
        std::cout << bytecount << "\n";
    }

    std::cout << "Client received : " << bytecount << "\n";

    close(sock);

    return 0;
}

server.cpp

// Server side C/C++ program to demonstrate Socket programming
#include <unistd.h>
#include <stdio.h>
#include <sys/socket.h>
#include <stdlib.h>
#include <netinet/in.h>
#include <string.h>
#include "util.h"

#define PORT 8080
int main(int argc, char const *argv[])
{
    int server_fd, new_socket, valread;
    struct sockaddr_in address;
    int opt = 1;
    int addrlen = sizeof(address);
    
    // Creating socket file descriptor
    if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0)
    {
        perror("socket failed");
        exit(EXIT_FAILURE);
    }
    
    // Forcefully attaching socket to the port 8080
    if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT,
                                                &opt, sizeof(opt)))
    {
        perror("setsockopt");
        exit(EXIT_FAILURE);
    }
    address.sin_family = AF_INET;
    address.sin_addr.s_addr = INADDR_ANY;
    address.sin_port = htons( PORT );
    
    // Forcefully attaching socket to the port 8080
    if (bind(server_fd, (struct sockaddr *)&address,
                                sizeof(address))<0)
    {
        perror("bind failed");
        exit(EXIT_FAILURE);
    }
    if (listen(server_fd, 3) < 0)
    {
        perror("listen");
        exit(EXIT_FAILURE);
    }
    if ((new_socket = accept(server_fd, (struct sockaddr *)&address,
                    (socklen_t*)&addrlen))<0)
    {
        perror("accept");
        exit(EXIT_FAILURE);
    }
    
     // set the socket to non-blocking mode
    fcntl(new_socket, F_SETFL, O_NONBLOCK);


    int64_t bytecount = -1;
    while (1) {
        std::unique_ptr<char[]> buffer;
        if ((bytecount = secure_recv(new_socket, buffer, 0))  <= 0) {
            if (bytecount == 0) {
                    break;
            }
        }
        std::cout << bytecount << "\n";
    }



    int numbytes;
    size_t size = 0;
    std::unique_ptr<char[]> buf = get_rand_data(size);
    if ((numbytes = secure_send(new_socket, buf.get(), size, 0)) == -1)
    {
        std::cout << std::strerror(errno) << "\n";
        exit(1);
    }
    std::cout << "Client sent : " << numbytes << "\n";


    close(new_socket);

    return 0;
}

util.h

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <time.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/wait.h>
#include <signal.h>
#include <cstring>
#include <iostream>
#include <thread>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <algorithm>
#include <memory>
#include <poll.h>
#include <iomanip>

/**
 * It takes as arguments one char[] array of 4 or bigger size and an integer.
 * It converts the integer into a byte array.
 */
void convertIntToByteArray(char *dst, int sz)
{
    auto tmp = dst;
    tmp[0] = (sz >> 24) & 0xFF;
    tmp[1] = (sz >> 16) & 0xFF;
    tmp[2] = (sz >> 8) & 0xFF;
    tmp[3] = sz & 0xFF;
}

/**
 * It takes as an argument a ptr to an array of size 4 or bigger and 
 * converts the char array into an integer.
 */
int convertByteArrayToInt(char *b)
{
    return (b[0] << 24) + ((b[1] & 0xFF) << 16) + ((b[2] & 0xFF) << 8) + (b[3] & 0xFF);
}

/**
 * It constructs the message to be sent. 
 * It takes as arguments a destination char ptr, the payload (data to be sent)
 * and the payload size.
 * It returns the expected message format at dst ptr;
 *
 *  |<---msg size (4 bytes)--->|<---payload (msg size bytes)--->|
 *
 *
 */
void construct_message(char *dst, char *payload, size_t payload_size)
{
    convertIntToByteArray(dst, payload_size);

    memcpy(dst + 4, payload, payload_size);
}

/**
 * It returns the actual size of msg.
 * Not that msg might not contain all payload data. 
 * The function expects at least that the msg contains the first 4 bytes that
 * indicate the actual size of the payload.
 */
int get_payload_size(char *msg, size_t bytes)
{
    // TODO:
    return convertByteArrayToInt(msg);
}

/**
 * Sends to the connection defined by the fd, a message with a payload (data) of size len bytes.
 * The fd should be non-blocking socket.
 */

/**
 * Receives a message from the fd (non-blocking) and stores it in buf.
 */
int secure_recv(int fd, std::unique_ptr<char[]> &buf)
{
    // TODO:
    int valread = 0;
    int len = 0;
    int _len = 4;
    bool once_received = false;

    std::vector<char> ptr(4);

    while (_len > 0)
    {

        int _valread = recv(fd, ptr.data() + valread, _len, MSG_DONTWAIT);

        if (_valread == 0)
        {
            break;
        }
        if (_valread > 0)
        {
            _len -= _valread;
            valread += _valread;
        }

        if (!once_received && valread == 4)
        {

            once_received = true;

            len = convertByteArrayToInt(ptr.data());

            _len = len;

            ptr = std::vector<char>(len);

            valread = 0;
        }
    }

    buf = std::make_unique<char[]>(len);

    memcpy(buf.get(), ptr.data(), len);

    return len;
}

/**
 * Sends to the connection defined by the fd, a message with a payload (data) of size len bytes.
 * The fd should be non-blocking socket.
 */
int secure_send(int fd, char *data, size_t len)
{
    // TODO:

    char ptr[len + 4];
    int valsent = 0;
    int _len = 4;
    bool once_sent = false;

    construct_message(ptr, data, len);

    while (_len > 0)
    {

        int _valsent = send(fd, ptr + valsent, _len, MSG_DONTWAIT);

        if (_valsent == 0)
        {
            break;
        }

        if (_valsent > 0)
        {
            _len -= _valsent;
            valsent += _valsent;
        }

        if (!once_sent && valsent == 4)
        {

            once_sent = true;

            _len = len;
        }
    }

    return len;
}

编译通过

g++ -O3 -std=c++17 -Wall -g -I../ client.cpp -o client -lpthread

让我们从写入循环开始:

while (1)
{

    // std::cerr << "first iteration send\n";

    FD_ZERO(&writeset);
    FD_SET(fd, &writeset);

    if (select(fd + 1, NULL, &writeset, NULL, NULL) > 0)
    {

        if (FD_ISSET(fd, &writeset))
        {
            valsent = send(fd, ptr + valsent, _len, 0);

糟糕。这会丢失 valsent,它会跟踪您到目前为止已发送的字节数。所以在你的第三个循环中,ptr + valsent 只会添加第二次收到的字节数。您需要跟踪到目前为止发送到某处的总字节数。

            if (valsent <= 0)
            {
                break;
            }

            _len -= valsent;

如果 _len 变为零怎么办?您仍然会调用 select 甚至 send。您可能希望 while (1) 成为 while (_len > 0).

        }
    }
}
return len;

现在,进入读取循环:

    if (select(fd + 1, &readset, NULL, NULL, NULL) > 0)
    {

        if (FD_ISSET(fd, &readset))
        {

            if (first_iteration)
            {
                recv(fd, ptr, 4, 0);

你在这里忽略 recv 的 return 值。如果不是 4 怎么办?

                len = convertByteArrayToInt(ptr);

                buf = std::make_unique<char[]>(len);

                _len = len;
                first_iteration = false;
            }

            valread = recv(fd, buf.get() + valread, _len, 0);

            if (valread <= 0)
            {
                break;
            }

            _len -= valread;

如果 _len 为零,则不要离开循环。您将再次调用 select,等待可能永远不会到来的数据。

        }
    }