Make the ssl handshake non blocking.

This commit is contained in:
poljar (Damir Jelić) 2018-01-30 19:29:03 +01:00
parent ad5472d5f2
commit d83c25e709
3 changed files with 103 additions and 58 deletions

158
main.py
View file

@ -127,40 +127,85 @@ def wrap_socket(server, file_descriptor):
else:
sock = temp_socket
try:
start = time.time()
message = "{prefix}matrix: Doing SSL handshake...".format(
prefix=W.prefix("network"))
sock.setblocking(False)
W.prnt(server.server_buffer, message)
message = "{prefix}matrix: Doing SSL handshake...".format(
prefix=W.prefix("network"))
W.prnt(server.server_buffer, message)
# TODO this blocks currently
ssl_socket = server.ssl_context.wrap_socket(
sock,
server_hostname=server.address) # type: ssl.SSLSocket
ssl_socket = server.ssl_context.wrap_socket(
sock,
do_handshake_on_connect=False,
server_hostname=server.address) # type: ssl.SSLSocket
cipher = ssl_socket.cipher()
cipher_message = ("{prefix}matrix: Connected using {tls}, and "
"{bit} bit {cipher} cipher suite.").format(
prefix=W.prefix("network"),
tls=cipher[1],
bit=cipher[2],
cipher=cipher[0])
server.socket = ssl_socket
W.prnt(server.server_buffer, cipher_message)
try_ssl_handshake(server)
# TODO print out the certificates
# cert = ssl_socket.getpeercert()
# W.prnt(server.server_buffer, pprint.pformat(cert))
server.lag = (time.time() - start) * 1000
W.bar_item_update("lag")
@utf8_decode
def ssl_fd_cb(server_name, file_descriptor):
server = SERVERS[server_name]
return ssl_socket
# TODO add finer grained error messages with the subclass exceptions
except ssl.SSLError as error:
server_buffer_prnt(server, str(error))
return None
if server.ssl_hook:
W.unhook(server.ssl_hook)
server.ssl_hook = None
try_ssl_handshake(server)
return W.WEECHAT_RC_OK
def try_ssl_handshake(server):
socket = server.socket
while True:
try:
socket.do_handshake()
cipher = socket.cipher()
cipher_message = ("{prefix}matrix: Connected using {tls}, and "
"{bit} bit {cipher} cipher suite.").format(
prefix=W.prefix("network"),
tls=cipher[1],
bit=cipher[2],
cipher=cipher[0])
W.prnt(server.server_buffer, cipher_message)
# TODO print out the certificates
# cert = socket.getpeercert()
# W.prnt(server.server_buffer, pprint.pformat(cert))
finalize_connection(server)
return True
except ssl.SSLWantReadError:
hook = W.hook_fd(
server.socket.fileno(),
1, 0, 0,
"ssl_fd_cb",
server.name
)
server.ssl_hook = hook
return False
except ssl.SSLWantWriteError:
hook = W.hook_fd(
server.socket.fileno(),
0, 1, 0,
"ssl_fd_cb",
server.name
)
server.ssl_hook = hook
return False
except ssl.SSLError as error:
server_buffer_prnt(server, str(error))
matrix_server_reconnect(server)
return False
@utf8_decode
@ -230,6 +275,32 @@ def receive_cb(server_name, file_descriptor):
return W.WEECHAT_RC_OK
def finalize_connection(server):
hook = W.hook_fd(
server.socket.fileno(),
1, 0, 0,
"receive_cb",
server.name
)
if not server.timer_hook:
server.timer_hook = W.hook_timer(
1 * 1000,
0,
0,
"matrix_timer_cb",
server.name
)
server.fd_hook = hook
server.connected = True
server.connecting = False
server.reconnect_count = 0
if not server.access_token:
matrix_login(server)
@utf8_decode
def connect_cb(data, status, gnutls_rc, sock, error, ip_address):
# pylint: disable=too-many-arguments,too-many-branches
@ -238,38 +309,11 @@ def connect_cb(data, status, gnutls_rc, sock, error, ip_address):
if status_value == W.WEECHAT_HOOK_CONNECT_OK:
file_descriptor = int(sock) # type: int
sock = wrap_socket(server, file_descriptor)
server.numeric_address = ip_address
server_buffer_set_title(server)
if sock:
server.socket = sock
hook = W.hook_fd(
server.socket.fileno(),
1, 0, 0,
"receive_cb",
server.name
)
wrap_socket(server, file_descriptor)
if not server.timer_hook:
server.timer_hook = W.hook_timer(
1 * 1000,
0,
0,
"matrix_timer_cb",
server.name
)
server.fd_hook = hook
server.connected = True
server.connecting = False
server.reconnect_count = 0
server.numeric_address = ip_address
server_buffer_set_title(server)
if not server.access_token:
matrix_login(server)
else:
matrix_server_reconnect(server)
return W.WEECHAT_RC_OK
elif status_value == W.WEECHAT_HOOK_CONNECT_ADDRESS_NOT_FOUND:

View file

@ -48,6 +48,7 @@ class MatrixServer:
self.buffers = dict() # type: Dict[str, weechat.buffer]
self.server_buffer = None # type: weechat.buffer
self.fd_hook = None # type: weechat.hook
self.ssl_hook = None # type: weechat.hook
self.timer_hook = None # type: weechat.hook
self.numeric_address = "" # type: str

View file

@ -93,7 +93,7 @@ def server_buffer_set_title(server):
else:
ip_string = ""
title = ("Matrix: {address}/{port}{ip}").format(
title = ("Matrix: {address}:{port}{ip}").format(
address=server.address,
port=server.port,
ip=ip_string)