在Python中使用ssl context.set_servername_callback



我的目标是允许ssl客户端从服务器的多个有效证书对中进行选择。客户端有一个CA证书,它将使用该证书来验证来自服务器的证书。

因此,为了实现这一点,我将服务器上的ssl.SSLContext.set_servername_callback()ssl.SSLSocket.wrap_socket's parameter:server_hostname`结合使用,以允许客户端指定要使用的密钥对。代码如下:

服务器代码:

import sys
import pickle
import ssl
import socket
import select
request = {'msgtype': 0, 'value': 'Ping', 'test': [chr(i) for i in range(256)]}
response = {'msgtype': 1, 'value': 'Pong'}
def handle_client(c, a):
print("Connection from {}:{}".format(*a))
req_raw = c.recv(10000)
req = pickle.loads(req_raw)
print("Received message: {}".format(req))
res = pickle.dumps(response)
print("Sending message: {}".format(response))
c.send(res)
def run_server(hostname, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((hostname, port))
s.listen(8)
print("Serving on {}:{}".format(hostname, port))
try:
while True:
(c, a) = s.accept()
def servername_callback(sock, req_hostname, cb_context, as_callback=True):
print('Loading certs for {}'.format(req_hostname))
server_cert = "ssl/{}/server".format(req_hostname)  # NOTE: This use of socket input is INSECURE
cb_context.load_cert_chain(certfile="{}.crt".format(server_cert), keyfile="{}.key".format(server_cert))
# Seems like this is designed usage: https://github.com/python/cpython/blob/3.4/Modules/_ssl.c#L1469
sock.context = cb_context
return None
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context.set_servername_callback(servername_callback)
default_cert = "ssl/3.1/server"
context.load_cert_chain(certfile="{}.crt".format(default_cert), keyfile="{}.key".format(default_cert))
ssl_sock = context.wrap_socket(c, server_side=True)
try:
handle_client(ssl_sock, a)
finally:
c.close()
except KeyboardInterrupt:
s.close()
if __name__ == '__main__':
hostname = ''
port = 6789
run_server(hostname, port)

客户代码:

import sys
import pickle
import socket
import ssl
request = {'msgtype': 0, 'value': 'Ping', 'test': [chr(i) for i in range(256)]}
response = {'msgtype': 1, 'value': 'Pong'}

def client(hostname, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print("Connecting to {}:{}".format(hostname, port))
s.connect((hostname, port))
ssl_sock = ssl.SSLSocket(sock=s, ca_certs="server_old.crt", cert_reqs=ssl.CERT_REQUIRED, server_hostname='3.2')
print("Sending message: {}".format(request))
req = pickle.dumps(request)
ssl_sock.send(req)
resp_raw = ssl_sock.recv(10000)
resp = pickle.loads(resp_raw)
print("Received message: {}".format(resp))
ssl_sock.close()
if __name__ == '__main__':
hostname = 'localhost'
port = 6789
client(hostname, port)

但它不起作用。似乎发生的情况是,servername_callback正在被调用,正在获得指定的"主机名",并且在回调中对context.load_cert_chain的调用没有失败(尽管如果给定的路径不存在,它确实会失败)。但是,服务器总是返回在调用context.wrap_socket(c, server_side=True)之前加载的证书对。因此,我的问题是:在servername_callback中,是否有某种方法可以修改ssl上下文使用的密钥对,并获得用于连接的密钥对证书?

我还应该注意,我检查了流量,直到servername_callback函数返回后,服务器的证书才会被发送(如果它不能成功完成,或者返回"失败"值,则永远不会被发送)。

在回调中,cb_context与调用wrap_socket()的上下文相同,与socket.context相同,因此socket.context = cb_context将上下文设置为与以前相同。

更改上下文的证书链不会影响用于当前wrap_socket()操作的证书。对此的解释在于openssl如何创建其底层对象,在这种情况下,底层SSL结构已经创建并使用链的副本:

NOTES

调用SSL_new()时,与SSL_CTX结构关联的链将复制到任何SSL结构。SSL结构将不受父SSL_CTX中随后更改的任何链的影响。

设置新上下文时,SSL结构会更新,但当新上下文与旧上下文相等时,不会执行该更新。

您需要将sock.context设置为不同的上下文才能使其工作。您当前在每个新的传入连接上实例化一个新的上下文,这是不需要的。相反,您应该只实例化一次标准上下文并重用它。动态加载的上下文也是如此,你可以在启动时创建它们,并将它们放在dict中,这样你就可以进行查找,例如:

...
contexts = {}
for hostname in os.listdir("ssl"):
print('Loading certs for {}'.format(hostname))
server_cert = "ssl/{}/server".format(hostname)
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile="{}.crt".format(server_cert),
keyfile="{}.key".format(server_cert))
contexts[hostname] = context
def servername_callback(sock, req_hostname, cb_context, as_callback=True):
context = contexts.get(req_hostname)
if context is not None:
sock.context = context
else:
pass  # handle unknown hostname case
def run_server(hostname, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((hostname, port))
s.listen(8)
print("Serving on {}:{}".format(hostname, port))
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context.set_servername_callback(servername_callback)
default_cert = "ssl/3.1/server"
context.load_cert_chain(certfile="{}.crt".format(default_cert),
keyfile="{}.key".format(default_cert))
try:
while True:
(c, a) = s.accept()
ssl_sock = context.wrap_socket(c, server_side=True)
try:
handle_client(ssl_sock, a)
finally:
c.close()
except KeyboardInterrupt:
s.close()

所以在看了这篇文章和其他一些在线文章后,我整理了上面代码的一个版本,它非常适合我。。。所以我只是想分享一下。以防对其他人有帮助。

import sys
import ssl
import socket
import os
from pprint import pprint
DOMAIN_CONTEXTS = {}
ssl_root_path = "c:/ssl/"
# ----------------------------------------------------------------------------------------------------------------------
#
# As an example create domains in the ssl root path...ie
#
# c:/ssl/example.com
# c:/ssl/johndoe.com
# c:/ssl/test.com
#
# And then create self signed ssl certificates for each domain to test... and put them in the corresponding domain 
# directory... in this case the cert and key files are called cert.pem, and key.pem.... 
#
def setup_ssl_certs():
global DOMAIN_CONTEXTS
for hostname in os.listdir(ssl_root_path):
#print('Loading certs for {}'.format(hostname))
# Establish the certificate and key folder...for the various domains...
server_cert = '{rp}{hn}/'.format(rp=ssl_root_path, hn=hostname)
# Setup the SSL Context manager object, for authentication
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
# Load the certificate file, and key file...into the context manager.
context.load_cert_chain(certfile="{}cert.pem".format(server_cert), keyfile="{}key.pem".format(server_cert))
# Set the context object to the global dictionary
DOMAIN_CONTEXTS[hostname] = context
# Uncomment for testing only.
#pprint(contexts)
# ----------------------------------------------------------------------------------------------------------------------
def servername_callback(sock, req_hostname, cb_context, as_callback=True):
"""
This is a callback function for the SSL Context manager, this is what does the real work of pulling the
domain name in the origional request.
"""
# Uncomment for testing only
#print(sock)
#print(req_hostname)
#print(cb_context)
context = DOMAIN_CONTEXTS.get(req_hostname)
if context:
try:
sock.context = context
except Exception as error:
print(error)
else:
sock.server_hostname = req_hostname
else:
pass  # handle unknown hostname case

def handle_client(conn, a):
request_domain = conn.server_hostname
request = conn.recv()
client_ip = conn.getpeername()[0]
resp = 'Hello {cip} welcome, from domain {d} !'.format(cip=client_ip, d=request_domain)
conn.write(b'HTTP/1.1 200 OKnn%s' % resp.encode())

def run_server(hostname, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((hostname, port))
s.listen(8)
#print("Serving on {}:{}".format(hostname, port))
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
# For Python 3.4+
context.set_servername_callback(servername_callback)
# Only available in 3.7 !!!! have not tested it yet...
#context.sni_callback(servername_callback)
default_cert = "{rp}default/".format(rp=ssl_root_path)
context.load_cert_chain(certfile="{}cert.pem".format(default_cert), keyfile="{}key.pem".format(default_cert))
context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1  # optional
context.set_ciphers('EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH')
try:
while True:
ssock, addr = s.accept()
try:
conn = context.wrap_socket(ssock, server_side=True)
except Exception as error:
print('!!! Error, {e}'.format(e=error))
except ssl.SSLError as e:
print(e)
else:
handle_client(conn, addr)
if conn:
conn.close()
#print('Connection closed !')
except KeyboardInterrupt:
s.close()
# ----------------------------------------------------------------------------------------------------------------------
def main():
setup_ssl_certs()
# Don't forget to update your static name resolution...  ie example.com = 127.0.0.1
run_server('example.com', 443)
# ----------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
main()

最新更新