diff --git a/main.py b/main.py index 8c72d71..0c9cdb2 100644 --- a/main.py +++ b/main.py @@ -127,40 +127,85 @@ def wrap_socket(server, file_descriptor): else: sock = temp_socket - try: - start = time.time() - message = "{prefix}matrix: Doing SSL handshake...".format( - prefix=W.prefix("network")) + sock.setblocking(False) - W.prnt(server.server_buffer, message) + message = "{prefix}matrix: Doing SSL handshake...".format( + prefix=W.prefix("network")) + W.prnt(server.server_buffer, message) - # TODO this blocks currently - ssl_socket = server.ssl_context.wrap_socket( - sock, - server_hostname=server.address) # type: ssl.SSLSocket + ssl_socket = server.ssl_context.wrap_socket( + sock, + do_handshake_on_connect=False, + server_hostname=server.address) # type: ssl.SSLSocket - cipher = ssl_socket.cipher() - cipher_message = ("{prefix}matrix: Connected using {tls}, and " - "{bit} bit {cipher} cipher suite.").format( - prefix=W.prefix("network"), - tls=cipher[1], - bit=cipher[2], - cipher=cipher[0]) + server.socket = ssl_socket - W.prnt(server.server_buffer, cipher_message) + try_ssl_handshake(server) - # TODO print out the certificates - # cert = ssl_socket.getpeercert() - # W.prnt(server.server_buffer, pprint.pformat(cert)) - server.lag = (time.time() - start) * 1000 - W.bar_item_update("lag") +@utf8_decode +def ssl_fd_cb(server_name, file_descriptor): + server = SERVERS[server_name] - return ssl_socket - # TODO add finer grained error messages with the subclass exceptions - except ssl.SSLError as error: - server_buffer_prnt(server, str(error)) - return None + if server.ssl_hook: + W.unhook(server.ssl_hook) + server.ssl_hook = None + + try_ssl_handshake(server) + + return W.WEECHAT_RC_OK + + +def try_ssl_handshake(server): + socket = server.socket + + while True: + try: + socket.do_handshake() + + cipher = socket.cipher() + cipher_message = ("{prefix}matrix: Connected using {tls}, and " + "{bit} bit {cipher} cipher suite.").format( + prefix=W.prefix("network"), + tls=cipher[1], + bit=cipher[2], + cipher=cipher[0]) + W.prnt(server.server_buffer, cipher_message) + + # TODO print out the certificates + # cert = socket.getpeercert() + # W.prnt(server.server_buffer, pprint.pformat(cert)) + + finalize_connection(server) + + return True + + except ssl.SSLWantReadError: + hook = W.hook_fd( + server.socket.fileno(), + 1, 0, 0, + "ssl_fd_cb", + server.name + ) + server.ssl_hook = hook + + return False + + except ssl.SSLWantWriteError: + hook = W.hook_fd( + server.socket.fileno(), + 0, 1, 0, + "ssl_fd_cb", + server.name + ) + server.ssl_hook = hook + + return False + + except ssl.SSLError as error: + server_buffer_prnt(server, str(error)) + matrix_server_reconnect(server) + return False @utf8_decode @@ -230,6 +275,32 @@ def receive_cb(server_name, file_descriptor): return W.WEECHAT_RC_OK +def finalize_connection(server): + hook = W.hook_fd( + server.socket.fileno(), + 1, 0, 0, + "receive_cb", + server.name + ) + + if not server.timer_hook: + server.timer_hook = W.hook_timer( + 1 * 1000, + 0, + 0, + "matrix_timer_cb", + server.name + ) + + server.fd_hook = hook + server.connected = True + server.connecting = False + server.reconnect_count = 0 + + if not server.access_token: + matrix_login(server) + + @utf8_decode def connect_cb(data, status, gnutls_rc, sock, error, ip_address): # pylint: disable=too-many-arguments,too-many-branches @@ -238,38 +309,11 @@ def connect_cb(data, status, gnutls_rc, sock, error, ip_address): if status_value == W.WEECHAT_HOOK_CONNECT_OK: file_descriptor = int(sock) # type: int - sock = wrap_socket(server, file_descriptor) + server.numeric_address = ip_address + server_buffer_set_title(server) - if sock: - server.socket = sock - hook = W.hook_fd( - server.socket.fileno(), - 1, 0, 0, - "receive_cb", - server.name - ) + wrap_socket(server, file_descriptor) - if not server.timer_hook: - server.timer_hook = W.hook_timer( - 1 * 1000, - 0, - 0, - "matrix_timer_cb", - server.name - ) - - server.fd_hook = hook - server.connected = True - server.connecting = False - server.reconnect_count = 0 - server.numeric_address = ip_address - - server_buffer_set_title(server) - - if not server.access_token: - matrix_login(server) - else: - matrix_server_reconnect(server) return W.WEECHAT_RC_OK elif status_value == W.WEECHAT_HOOK_CONNECT_ADDRESS_NOT_FOUND: diff --git a/matrix/server.py b/matrix/server.py index bdf49ff..b4b8eac 100644 --- a/matrix/server.py +++ b/matrix/server.py @@ -48,6 +48,7 @@ class MatrixServer: self.buffers = dict() # type: Dict[str, weechat.buffer] self.server_buffer = None # type: weechat.buffer self.fd_hook = None # type: weechat.hook + self.ssl_hook = None # type: weechat.hook self.timer_hook = None # type: weechat.hook self.numeric_address = "" # type: str diff --git a/matrix/utils.py b/matrix/utils.py index 62b83c9..749600e 100644 --- a/matrix/utils.py +++ b/matrix/utils.py @@ -93,7 +93,7 @@ def server_buffer_set_title(server): else: ip_string = "" - title = ("Matrix: {address}/{port}{ip}").format( + title = ("Matrix: {address}:{port}{ip}").format( address=server.address, port=server.port, ip=ip_string)