在 Python 中使用 ssl context.set_servername_callback

Using ssl context.set_servername_callback in Python

我的目标是允许 ssl 客户端 select 来自服务器的多个有效证书对。客户端有一个 CA 证书,它将用于验证来自服务器的证书。

因此,为了尝试实现这一点,我在服务器上结合使用 ssl.SSLContext.set_servername_callback()ssl.SSLSocket.wrap_socket's parameter:server_hostname` 来尝试允许客户端指定要发送的密钥对采用。代码如下所示:

服务器代码:

import sys
import pickle
import ssl
import socket
import select

request = {'msgtype': 0, 'value': 'Ping', 'test': [chr(i) for i in range(256)]}
response = {'msgtype': 1, 'value': 'Pong'}

def handle_client(c, a):
    print("Connection from {}:{}".format(*a))
    req_raw = c.recv(10000)
    req = pickle.loads(req_raw)
    print("Received message: {}".format(req))
    res = pickle.dumps(response)
    print("Sending message: {}".format(response))
    c.send(res)

def run_server(hostname, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind((hostname, port))
    s.listen(8)
    print("Serving on {}:{}".format(hostname, port))

    try:
        while True:
            (c, a) = s.accept()

            def servername_callback(sock, req_hostname, cb_context, as_callback=True):
                print('Loading certs for {}'.format(req_hostname))
                server_cert = "ssl/{}/server".format(req_hostname)  # NOTE: This use of socket input is INSECURE
                cb_context.load_cert_chain(certfile="{}.crt".format(server_cert), keyfile="{}.key".format(server_cert))

                # Seems like this is designed usage: https://github.com/python/cpython/blob/3.4/Modules/_ssl.c#L1469
                sock.context = cb_context
                return None

            context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
            context.set_servername_callback(servername_callback)
            default_cert = "ssl/3.1/server"
            context.load_cert_chain(certfile="{}.crt".format(default_cert), keyfile="{}.key".format(default_cert))
            ssl_sock = context.wrap_socket(c, server_side=True)

            try:
                handle_client(ssl_sock, a)
            finally:
                c.close()

    except KeyboardInterrupt:
        s.close()

if __name__ == '__main__':
    hostname = ''
    port = 6789
    run_server(hostname, port)

客户代码:

import sys
import pickle
import socket
import ssl

request = {'msgtype': 0, 'value': 'Ping', 'test': [chr(i) for i in range(256)]}
response = {'msgtype': 1, 'value': 'Pong'}


def client(hostname, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    print("Connecting to {}:{}".format(hostname, port))
    s.connect((hostname, port))

    ssl_sock = ssl.SSLSocket(sock=s, ca_certs="server_old.crt", cert_reqs=ssl.CERT_REQUIRED, server_hostname='3.2')

    print("Sending message: {}".format(request))
    req = pickle.dumps(request)
    ssl_sock.send(req)

    resp_raw = ssl_sock.recv(10000)
    resp = pickle.loads(resp_raw)
    print("Received message: {}".format(resp))

    ssl_sock.close()

if __name__ == '__main__':
    hostname = 'localhost'
    port = 6789
    client(hostname, port)

但它不起作用。似乎正在发生的事情是 servername_callback 被调用,正在获取指定的 "hostname",并且在回调中对 context.load_cert_chain 的调用没有失败(尽管如果给定路径它确实会失败不存在)。但是,服务器总是 returns 在调用 context.wrap_socket(c, server_side=True) 之前加载的证书对。所以我的问题是:在 servername_callback 中是否有某种方法可以修改 ssl 上下文使用的密钥对,并获取该密钥对的证书以用于连接?

我还应该注意,我检查了流量,直到 servername_callback 函数 returns 之后才发送服务器证书(如果未能成功完成,则永远不会发送,或 returns 一个 "failure" 值)。

在您的回调中,cb_context 与调用 wrap_socket() 的上下文相同,并且与 socket.context 相同,因此 socket.context = cb_context 将上下文设置为相同以前是。

更改上下文的证书链不会影响用于当前 wrap_socket() 操作的证书。对此的解释在于 openssl 如何创建其底层对象,在这种情况下,底层 SSL 结构已经创建并使用 copies of the chains:

NOTES

The chains associate with an SSL_CTX structure are copied to any SSL structures when SSL_new() is called. SSL structures will not be affected by any chains subsequently changed in the parent SSL_CTX.

设置新上下文时,会更新 SSL 结构,但 new context is equal to the old one 时不会执行该更新。

您需要将 sock.context 设置为 不同的 上下文才能使其正常工作。您目前在每个新的传入连接上实例化一个新的上下文,这是不需要的。相反,您应该只实例化您的标准上下文一次并重用它。动态加载的上下文也是如此,您可以在启动时创建它们并将它们放入字典中,这样您就可以进行查找,例如:

...

contexts = {}

for hostname in os.listdir("ssl"):
    print('Loading certs for {}'.format(hostname))
    server_cert = "ssl/{}/server".format(hostname)
    context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(certfile="{}.crt".format(server_cert),
                            keyfile="{}.key".format(server_cert))
    contexts[hostname] = context

def servername_callback(sock, req_hostname, cb_context, as_callback=True):
    context = contexts.get(req_hostname)
    if context is not None:
        sock.context = context
    else:
        pass  # handle unknown hostname case

def run_server(hostname, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.bind((hostname, port))
    s.listen(8)
    print("Serving on {}:{}".format(hostname, port))

    context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
    context.set_servername_callback(servername_callback)
    default_cert = "ssl/3.1/server"
    context.load_cert_chain(certfile="{}.crt".format(default_cert),
                            keyfile="{}.key".format(default_cert))

    try:
        while True:
            (c, a) = s.accept()
            ssl_sock = context.wrap_socket(c, server_side=True)
            try:
                handle_client(ssl_sock, a)
            finally:
                c.close()

    except KeyboardInterrupt:
        s.close()

因此,在查看了这个 post 和其他一些在线代码之后,我整理了上面代码的一个版本,它非常适合我...所以我只是想分享一下。以防它帮助其他人。

import sys
import ssl
import socket
import os

from pprint import pprint

DOMAIN_CONTEXTS = {}

ssl_root_path = "c:/ssl/"

# ----------------------------------------------------------------------------------------------------------------------
#
# As an example create domains in the ssl root path...ie
#
# c:/ssl/example.com
# c:/ssl/johndoe.com
# c:/ssl/test.com
#
# And then create self signed ssl certificates for each domain to test... and put them in the corresponding domain 
# directory... in this case the cert and key files are called cert.pem, and key.pem.... 
#

def setup_ssl_certs():

    global DOMAIN_CONTEXTS

    for hostname in os.listdir(ssl_root_path):

        #print('Loading certs for {}'.format(hostname))

        # Establish the certificate and key folder...for the various domains...
        server_cert = '{rp}{hn}/'.format(rp=ssl_root_path, hn=hostname)

        # Setup the SSL Context manager object, for authentication
        context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)

        # Load the certificate file, and key file...into the context manager.
        context.load_cert_chain(certfile="{}cert.pem".format(server_cert), keyfile="{}key.pem".format(server_cert))

        # Set the context object to the global dictionary
        DOMAIN_CONTEXTS[hostname] = context

    # Uncomment for testing only.
    #pprint(contexts)

# ----------------------------------------------------------------------------------------------------------------------

def servername_callback(sock, req_hostname, cb_context, as_callback=True):
    """
    This is a callback function for the SSL Context manager, this is what does the real work of pulling the
    domain name in the origional request.
    """

    # Uncomment for testing only
    #print(sock)
    #print(req_hostname)
    #print(cb_context)

    context = DOMAIN_CONTEXTS.get(req_hostname)

    if context:

        try:
            sock.context = context
        except Exception as error:
            print(error)
        else:
            sock.server_hostname = req_hostname

    else:
        pass  # handle unknown hostname case


def handle_client(conn, a):

    request_domain = conn.server_hostname

    request = conn.recv()

    client_ip = conn.getpeername()[0]

    resp = 'Hello {cip} welcome, from domain {d} !'.format(cip=client_ip, d=request_domain)

    conn.write(b'HTTP/1.1 200 OK\n\n%s' % resp.encode())


def run_server(hostname, port):

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    s.bind((hostname, port))

    s.listen(8)

    #print("Serving on {}:{}".format(hostname, port))

    context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)

    # For Python 3.4+
    context.set_servername_callback(servername_callback)

    # Only available in 3.7 !!!! have not tested it yet...
    #context.sni_callback(servername_callback)

    default_cert = "{rp}default/".format(rp=ssl_root_path)

    context.load_cert_chain(certfile="{}cert.pem".format(default_cert), keyfile="{}key.pem".format(default_cert))

    context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1  # optional

    context.set_ciphers('EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH')

    try:
        while True:

            ssock, addr = s.accept()

            try:
                conn = context.wrap_socket(ssock, server_side=True)

            except Exception as error:
                print('!!! Error, {e}'.format(e=error))

            except ssl.SSLError as e:
                print(e)

            else:
                handle_client(conn, addr)

                if conn:
                    conn.close()
                    #print('Connection closed !')

    except KeyboardInterrupt:
        s.close()

# ----------------------------------------------------------------------------------------------------------------------

def main():

    setup_ssl_certs()

    # Don't forget to update your static name resolution...  ie example.com = 127.0.0.1
    run_server('example.com', 443)

# ----------------------------------------------------------------------------------------------------------------------

if __name__ == '__main__':
    main()