HTTP(S) 代理的简单低级实现不起作用

Simple low level implementation of HTTP(S) proxy is not working

我正在 Java 中实现简单的 HTTP(S) proxy/tunnel。所有这些都是在使用套接字创建隧道的多线程应用程序中完成的。我提出的代码适用于 HTTP,但不适用于 HTTPS。我遇到的问题是链中的最后一个服务器(通过我的隧道)在尝试打开到目标服务器的套接字连接时抛出连接超时。

换句话说,使用 HTTPS 连接:

Request:
client ---> tunnel-srv1 ---> tunnel-srv2 ---> HTTPS-server (https://facebook.com for example)
                                            ^ Timeout here on new Socket("facebook.com", 443);

tunnel-srv1(省略无关代码):

public class ConnectionMainThread implements Runnable {

    private Socket clientSocket;
    private Socket remoteSocket;

    public ConnectionMainThread(Socket clientSocket) {
        this.clientSocket = clientSocket;
    }

    @Override
    public void run() {
        try {
            // Open connection to tunnel-srv2 server
            remoteSocket = new Socket("127.0.0.1", 8081);

            // Create thread client->remote
            Thread clientToRemote = new ReceiveSendThread(clientSocket, remoteSocket, true);
            clientToRemote.start();

            // Create thread remote->client
            Thread remoteToClient = new ReceiveSendThread(remoteSocket, clientSocket);
            remoteToClient.start();

            // Block thread until both other threads are released
            clientToRemote.join();
            remoteToClient.join();

            // Make sure all the connection are closed
            remoteSocket.close();
            clientSocket.close();
        } catch (Exception e) {
            //e.printStackTrace();
        }
    }
}

public class ReceiveSendThread extends Thread {

    private Socket inSocket;
    private Socket outSocket;
    private boolean isFromClient = false;

    public ReceiveSendThread(Socket inSocket, Socket outSocket, boolean isFromClient) {
        this.inSocket = inSocket;
        this.outSocket = outSocket;
        this.isFromClient = isFromClient;
    }

    public ReceiveSendThread(Socket inSocket, Socket outSocket) {
        this.inSocket = inSocket;
        this.outSocket = outSocket;
    }

    @Override
    public void run() {
        try (
            OutputStream out = outSocket.getOutputStream();
            InputStream in = inSocket.getInputStream();
        ) {
            int nRead;
            byte[] data = new byte[16384];

            if (this.isFromClient) {
                nRead = in.read(data, 0 ,data.length);
                byte[] first = new byte[4096];
                System.arraycopy(data, 0, first, 0, first.length);
                String str = new String(first);
                String uri = (str.split(" "))[1];
                URL url = new URL(uri);
                String host = url.getURL();
                byte[] bytes = host.getBytes();
                byte[] firstLine = new byte[4096];
                System.arraycopy(bytes, 0, firstLine, 0, bytes.length);
                out.write(firstLine);
                out.flush();
                out.write(data, 0, nRead);
                out.flush();
            }


            while ((nRead = in.read(data, 0 ,data.length)) != -1) {
                out.write(data, 0, nRead);
                out.flush();
            }

            // Input socket is closing
            outSocket.shutdownOutput();

        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
public class URL {

    private String protocol = "http";
    private String host = null;
    private String port = null;
    private String path = "/";

    public URL(String str) throws MalformedURLException {
        int protoPos = str.indexOf("://");
        if (protoPos != -1) {
            str = str.substring(protoPos+3);
        }
        int pathPos = str.indexOf("/");
        if (pathPos == 0) {
            throw new MalformedURLException("no hostname present");
        }
        if (pathPos > 0) {
            str = str.substring(0, pathPos);
        }
        int portDelimiterPos = str.indexOf(":");
        if (portDelimiterPos == -1) {
            port = getDefaultPort();
            host = str;
        } else {
            port = str.substring(portDelimiterPos + 1);
            host = str.substring(0, portDelimiterPos);
        }
    }

    public String getURL(){
        return host+":"+port;
    }

    private String getDefaultPort(){
        String port;
        switch (this.protocol){
            case "http":
                port = "80";
                break;
            case "https":
                port = "443";
                break;
            case "ws":
                port = "80";
                break;
            default:
                port = "80";
        }
        return port;
    }
}

tunnel-srv2(省略无关代码):

public class ConnectionMainThread implements Runnable {

    private Socket clientSocket;
    private Socket remoteSocket;

    private InetAddress addr;
    private Integer port;
    private String str;

    public ConnectionMainThread(Socket clientSocket) {
        this.clientSocket = clientSocket;
    }

    @Override
    public void run() {
        try {
            // Find remote server location (read first line)
            InputStream in = clientSocket.getInputStream();
            int nRead;
            byte[] data = new byte[4096];

            nRead = in.read(data, 0 ,data.length);
            if (nRead < 0) {
                clientSocket.shutdownOutput();
                clientSocket.close();
                return;
            }
            str = new String(data, 0, nRead);
            String[] res = str.split("\:");
            String host = res[0];
            port = Integer.parseInt(res[1].replaceAll("\D+", ""));
            addr = Inet4Address.getByName(host);
            //String addr = host;


            // Open connection to destination server
            remoteSocket = new Socket(addr, port);

            // Create thread client->remote
            Thread clientToRemote = new ReceiveSendThread(clientSocket, remoteSocket, true);
            clientToRemote.start();

            // Create thread remote->client
            Thread remoteToClient = new ReceiveSendThread(remoteSocket, clientSocket);
            remoteToClient.start();

            // Block thread until both other threads are released
            clientToRemote.join();
            remoteToClient.join();

            // Make sure all the connection are closed
            remoteSocket.close();
            clientSocket.close();
        } catch (ConnectException e) {
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
public class ReceiveSendThread extends Thread {

    private Socket inSocket;
    private Socket outSocket;
    private boolean dump = false;

    public ReceiveSendThread(Socket inSocket, Socket outSocket, boolean dump) {
        this.inSocket = inSocket;
        this.outSocket = outSocket;
        this.dump = dump;
    }

    public ReceiveSendThread(Socket inSocket, Socket outSocket) {
        this.inSocket = inSocket;
        this.outSocket = outSocket;
    }

    @Override
    public void run() {
        try (
            OutputStream out = outSocket.getOutputStream();
            InputStream in = inSocket.getInputStream();
        ) {
            int nRead;
            byte[] data = new byte[16384];

            while ((nRead = in.read(data, 0 ,data.length)) != -1) {
                if (this.dump) {
                    System.out.print(new String(data, 0, nRead));
                }
                out.write(data, 0, nRead);
                out.flush();
            }

            // Input socket is closing
            outSocket.shutdownOutput();

        } catch (IOException e) {
            //e.printStackTrace();
        }
    }
}

原来的问题不在代码上,而是在逻辑上。

来自 RFC documentation

A CONNECT method requests that a proxy establish a tunnel connection
on its behalf.

这意味着代理在收到 CONNECT 请求后应该建立到目的地的隧道,这反过来又意味着 CONNECT 请求并不意味着到达目的地并且仅用于代理服务器,而问题中的代码也会发送这个请求到目的地。尽管 RFC 提到在这种情况下,接收此请求的服务器可能会使用 2xx 代码进行响应,但这不是强制性的,而且似乎不是它在测试服务器上的工作方式。

所以解决方案实际上是正确解析传入请求,在收到 CONNECT 请求的情况下,只需打开连接而不将其发送到目的地。