diff --git a/matrix/api.py b/matrix/api.py index 8d2799c..be59762 100644 --- a/matrix/api.py +++ b/matrix/api.py @@ -21,6 +21,11 @@ import time import json from enum import Enum, unique +try: + from urllib import quote, urlencode +except ImportError: + from urllib.parse import quote, urlencode + from matrix.globals import OPTIONS from matrix.http import RequestType, HttpRequest @@ -41,6 +46,56 @@ class MessageType(Enum): INVITE = 8 +class MatrixClient: + def __init__( + self, + host, # type: str + access_token="", # type: str + user_agent="" # type: str + ): + self.host = host + self.user_agent = user_agent + self.access_token = access_token + self.txn_id = 0 # type: int + + def login(self, user, password, device_name=""): + # type () -> HttpRequest + path = ("{api}/login").format(api=MATRIX_API_PATH) + + post_data = { + "type": "m.login.password", + "user": user, + "password": password + } + + if device_name: + post_data["initial_device_display_name"] = device_name + + return HttpRequest(RequestType.POST, self.host, path, post_data) + + def sync(self, next_batch="", sync_filter=None): + # type: (str, Dict[Any, Any]) -> HttpRequest + assert self.access_token + + query_parameters = {"access_token": self.access_token} + + if sync_filter: + query_parameters["filter"] = json.dumps( + sync_filter, + separators=(",", ":") + ) + + if next_batch: + query_parameters["since"] = next_batch + + path = ("{api}/sync?{query_params}").format( + api=MATRIX_API_PATH, + query_params=urlencode(query_parameters) + ) + + return HttpRequest(RequestType.GET, self.host, path) + + class MatrixMessage: def __init__( self, @@ -63,14 +118,13 @@ class MatrixMessage: self.send_time = None # type: float self.receive_time = None # type: float + host = ':'.join([server.address, str(server.port)]) + if message_type == MessageType.LOGIN: - path = ("{api}/login").format(api=MATRIX_API_PATH) - self.request = HttpRequest( - RequestType.POST, - server.address, - server.port, - path, - data + self.request = server.client.login( + server.user, + server.password, + server.device_name ) elif message_type == MessageType.SYNC: @@ -80,23 +134,7 @@ class MatrixMessage: } } - path = ("{api}/sync?access_token={access_token}&" - "filter={sync_filter}").format( - api=MATRIX_API_PATH, - access_token=server.access_token, - sync_filter=json.dumps(sync_filter, - separators=(',', ':'))) - - if server.next_batch: - path = path + '&since={next_batch}'.format( - next_batch=server.next_batch) - - self.request = HttpRequest( - RequestType.GET, - server.address, - server.port, - path - ) + self.request = server.client.sync(server.next_batch, sync_filter) elif message_type == MessageType.SEND: path = ("{api}/rooms/{room}/send/m.room.message/{tx_id}?" @@ -108,8 +146,7 @@ class MatrixMessage: self.request = HttpRequest( RequestType.PUT, - server.address, - server.port, + host, path, data ) @@ -124,8 +161,7 @@ class MatrixMessage: self.request = HttpRequest( RequestType.PUT, - server.address, - server.port, + host, path, data ) @@ -141,8 +177,7 @@ class MatrixMessage: self.request = HttpRequest( RequestType.PUT, - server.address, - server.port, + host, path, data ) @@ -158,8 +193,7 @@ class MatrixMessage: access_token=server.access_token) self.request = HttpRequest( RequestType.GET, - server.address, - server.port, + host, path, ) @@ -172,8 +206,7 @@ class MatrixMessage: self.request = HttpRequest( RequestType.POST, - server.address, - server.port, + host, path, data ) @@ -187,8 +220,7 @@ class MatrixMessage: self.request = HttpRequest( RequestType.POST, - server.address, - server.port, + host, path, data ) @@ -202,8 +234,7 @@ class MatrixMessage: self.request = HttpRequest( RequestType.POST, - server.address, - server.port, + host, path, data ) @@ -245,15 +276,9 @@ def matrix_sync(server): def matrix_login(server): # type: (MatrixServer) -> None - post_data = {"type": "m.login.password", - "user": server.user, - "password": server.password, - "initial_device_display_name": server.device_name} - message = MatrixMessage( server, OPTIONS, - MessageType.LOGIN, - data=post_data + MessageType.LOGIN ) server.send_or_queue(message) diff --git a/matrix/http.py b/matrix/http.py index 6b4d881..8638f3e 100644 --- a/matrix/http.py +++ b/matrix/http.py @@ -40,17 +40,14 @@ class HttpRequest: self, request_type, # type: RequestType host, # type: str - port, # type: int location, # type: str data=None, # type: Dict[str, Any] user_agent='weechat-matrix/{version}'.format( version="0.1") # type: str ): # type: (...) -> None - host_string = ':'.join([host, str(port)]) - user_agent = 'User-Agent: {agent}'.format(agent=user_agent) - host_header = 'Host: {host}'.format(host=host_string) + host_header = 'Host: {host}'.format(host=host) request_list = [] # type: List[str] accept_header = 'Accept: */*' # type: str end_separator = '\r\n' # type: str diff --git a/matrix/messages.py b/matrix/messages.py index e5db99c..ff1e375 100644 --- a/matrix/messages.py +++ b/matrix/messages.py @@ -686,8 +686,9 @@ def matrix_handle_message( if message_type is MessageType.LOGIN: server.access_token = response["access_token"] server.user_id = response["user_id"] - message = MatrixMessage(server, OPTIONS, MessageType.SYNC) - server.send_or_queue(message) + server.client.access_token = server.access_token + + matrix_sync(server) elif message_type is MessageType.SYNC: next_batch = response['next_batch'] diff --git a/matrix/server.py b/matrix/server.py index cf899f2..7ffedd7 100644 --- a/matrix/server.py +++ b/matrix/server.py @@ -34,6 +34,7 @@ from matrix.utils import ( ) from matrix.utf import utf8_decode from matrix.globals import W, SERVERS +from matrix.api import MatrixClient class MatrixServer: @@ -66,6 +67,7 @@ class MatrixServer: self.socket = None # type: ssl.SSLSocket self.ssl_context = ssl.create_default_context() # type: ssl.SSLContext + self.client = None self.access_token = None # type: str self.next_batch = None # type: str self.transaction_id = 0 # type: int @@ -143,16 +145,23 @@ class MatrixServer: self.http_parser = HttpParser() self.http_buffer = [] + def _change_client(self): + host = ':'.join([self.address, str(self.port)]) + user_agent = 'weechat-matrix/{version}'.format(version="0.1") + self.client = MatrixClient(host, user_agent=user_agent) + def update_option(self, option, option_name): if option_name == "address": value = W.config_string(option) self.address = value + self._change_client() elif option_name == "autoconnect": value = W.config_boolean(option) self.autoconnect = value elif option_name == "port": value = W.config_integer(option) self.port = value + self._change_client() elif option_name == "ssl_verify": value = W.config_boolean(option) if value: